Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Merge pull request #806 from alibaba-damo-academy/dev_wjm_sd: update eend-ola
Commit: 993f226f35
@@ -0,0 +1,45 @@
# network architecture
# encoder related
encoder: eend_ola_transformer
encoder_conf:
    idim: 345
    n_layers: 4
    n_units: 256

# encoder-decoder attractor related
encoder_decoder_attractor: eda
encoder_decoder_attractor_conf:
    n_units: 256

# model related
model: eend_ola
model_conf:
    attractor_loss_weight: 0.01
    max_n_speaker: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 100
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
  - loss
  - min
keep_nbest_models: 100

optim: adam
optim_conf:
    lr: 0.00001

dataset_conf:
    data_names: speech_speaker_labels
    data_types: kaldi_ark
    batch_conf:
        batch_type: unsorted
        batch_size: 8
    num_workers: 8

log_interval: 50
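Note: attractor_loss_weight controls how the attractor-existence loss is mixed into the total training objective. A minimal sketch of the usual EEND-EDA-style combination (an illustration only, not FunASR's actual training code; the exact loss terms and label shapes are assumptions):

import torch
import torch.nn.functional as F

def eend_ola_loss(diar_logits, diar_labels, attractor_logits, attractor_labels,
                  attractor_loss_weight=0.01):
    # Frame-level diarization loss: per-frame, per-speaker binary cross-entropy
    # (labels are 0/1 floats after permutation resolution).
    diar_loss = F.binary_cross_entropy_with_logits(diar_logits, diar_labels)
    # Attractor existence loss: BCE on "is this attractor a real speaker?" logits.
    attractor_loss = F.binary_cross_entropy_with_logits(attractor_logits, attractor_labels)
    # Weighted sum, matching attractor_loss_weight in the config above.
    return diar_loss + attractor_loss_weight * attractor_loss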
@@ -0,0 +1,52 @@
# network architecture
# encoder related
encoder: eend_ola_transformer
encoder_conf:
    idim: 345
    n_layers: 4
    n_units: 256

# encoder-decoder attractor related
encoder_decoder_attractor: eda
encoder_decoder_attractor_conf:
    n_units: 256

# model related
model: eend_ola
model_conf:
    max_n_speaker: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 100
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
  - loss
  - min
keep_nbest_models: 100

optim: adam
optim_conf:
    lr: 1.0
    betas:
    - 0.9
    - 0.98
    eps: 1.0e-9
scheduler: noamlr
scheduler_conf:
    model_size: 256
    warmup_steps: 100000

dataset_conf:
    data_names: speech_speaker_labels
    data_types: kaldi_ark
    batch_conf:
        batch_type: unsorted
        batch_size: 64
    num_workers: 8

log_interval: 50
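Note: with scheduler: noamlr, the nominal lr: 1.0 acts as a scale factor on the Noam schedule; the effective learning rate warms up for warmup_steps and then decays with the inverse square root of the step. A sketch of the standard formula (assuming FunASR's noamlr follows the usual definition from "Attention Is All You Need"):

def noam_lr(step, base_lr=1.0, model_size=256, warmup_steps=100000):
    # Effective LR = base_lr * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)  # avoid division by zero at step 0
    return base_lr * model_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# Under this config the peak LR lands at step == warmup_steps:
# noam_lr(100000) ~= 256**-0.5 * 100000**-0.5 ~= 2.0e-4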
@@ -0,0 +1,52 @@
# network architecture
# encoder related
encoder: eend_ola_transformer
encoder_conf:
    idim: 345
    n_layers: 4
    n_units: 256

# encoder-decoder attractor related
encoder_decoder_attractor: eda
encoder_decoder_attractor_conf:
    n_units: 256

# model related
model: eend_ola
model_conf:
    max_n_speaker: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 25
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
  - loss
  - min
keep_nbest_models: 100

optim: adam
optim_conf:
    lr: 1.0
    betas:
    - 0.9
    - 0.98
    eps: 1.0e-9
scheduler: noamlr
scheduler_conf:
    model_size: 256
    warmup_steps: 100000

dataset_conf:
    data_names: speech_speaker_labels
    data_types: kaldi_ark
    batch_conf:
        batch_type: unsorted
        batch_size: 64
    num_workers: 8

log_interval: 50
@@ -0,0 +1,44 @@
# network architecture
# encoder related
encoder: eend_ola_transformer
encoder_conf:
    idim: 345
    n_layers: 4
    n_units: 256

# encoder-decoder attractor related
encoder_decoder_attractor: eda
encoder_decoder_attractor_conf:
    n_units: 256

# model related
model: eend_ola
model_conf:
    max_n_speaker: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 1
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
  - loss
  - min
keep_nbest_models: 100

optim: adam
optim_conf:
    lr: 0.00001

dataset_conf:
    data_names: speech_speaker_labels
    data_types: kaldi_ark
    batch_conf:
        batch_type: unsorted
        batch_size: 8
    num_workers: 8

log_interval: 50
egs/callhome/eend_ola/local/dump_feature.py (normal file, 144 lines)
@@ -0,0 +1,144 @@
import argparse
import os

from kaldiio import WriteHelper

import funasr.modules.eend_ola.utils.feature as feature
from funasr.modules.eend_ola.utils.kaldi_data import load_segments_rechash, load_utt2spk, load_wav_scp, load_reco2dur, \
    load_spk2utt, load_wav


def _count_frames(data_len, size, step):
    return int((data_len - size + step) / step)


def _gen_frame_indices(
        data_length, size=2000, step=2000,
        use_last_samples=False,
        label_delay=0,
        subsampling=1):
    i = -1
    for i in range(_count_frames(data_length, size, step)):
        yield i * step, i * step + size
    if use_last_samples and i * step + size < data_length:
        if data_length - (i + 1) * step - subsampling * label_delay > 0:
            yield (i + 1) * step, data_length


class KaldiData:
    def __init__(self, data_dir, idx):
        self.data_dir = data_dir
        segment_file = os.path.join(self.data_dir, 'segments.{}'.format(idx))
        self.segments = load_segments_rechash(segment_file)

        utt2spk_file = os.path.join(self.data_dir, 'utt2spk.{}'.format(idx))
        self.utt2spk = load_utt2spk(utt2spk_file)

        wav_file = os.path.join(self.data_dir, 'wav.scp.{}'.format(idx))
        self.wavs = load_wav_scp(wav_file)

        reco2dur_file = os.path.join(self.data_dir, 'reco2dur.{}'.format(idx))
        self.reco2dur = load_reco2dur(reco2dur_file)

        spk2utt_file = os.path.join(self.data_dir, 'spk2utt.{}'.format(idx))
        self.spk2utt = load_spk2utt(spk2utt_file)

    def load_wav(self, recid, start=0, end=None):
        data, rate = load_wav(self.wavs[recid], start, end)
        return data, rate


class KaldiDiarizationDataset:
    def __init__(
            self,
            data_dir,
            index,
            chunk_size=2000,
            context_size=0,
            frame_size=1024,
            frame_shift=256,
            subsampling=1,
            rate=16000,
            input_transform=None,
            use_last_samples=False,
            label_delay=0,
            n_speakers=None,
    ):
        self.data_dir = data_dir
        self.index = index
        self.chunk_size = chunk_size
        self.context_size = context_size
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.subsampling = subsampling
        self.input_transform = input_transform
        self.n_speakers = n_speakers
        self.chunk_indices = []
        self.label_delay = label_delay

        self.data = KaldiData(self.data_dir, index)

        for rec, path in self.data.wavs.items():
            data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
            data_len = int(data_len / self.subsampling)
            for st, ed in _gen_frame_indices(
                    data_len, chunk_size, chunk_size, use_last_samples,
                    label_delay=self.label_delay,
                    subsampling=self.subsampling):
                self.chunk_indices.append(
                    (rec, path, st * self.subsampling, ed * self.subsampling))
        print(len(self.chunk_indices), " chunks")


def convert(args):
    dataset = KaldiDiarizationDataset(
        data_dir=args.data_dir,
        index=args.index,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        input_transform="logmel23_mn",
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=8000,
        use_last_samples=True,
    )

    feature_ark_file = os.path.join(args.output_dir, "feature.ark.{}".format(args.index))
    feature_scp_file = os.path.join(args.output_dir, "feature.scp.{}".format(args.index))
    label_ark_file = os.path.join(args.output_dir, "label.ark.{}".format(args.index))
    label_scp_file = os.path.join(args.output_dir, "label.scp.{}".format(args.index))
    with WriteHelper('ark,scp:{},{}'.format(feature_ark_file, feature_scp_file)) as feature_writer, \
            WriteHelper('ark,scp:{},{}'.format(label_ark_file, label_scp_file)) as label_writer:
        for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
            Y, T = feature.get_labeledSTFT(
                dataset.data,
                rec,
                st,
                ed,
                dataset.frame_size,
                dataset.frame_shift,
                dataset.n_speakers)
            Y = feature.transform(Y, dataset.input_transform)
            Y_spliced = feature.splice(Y, dataset.context_size)
            Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
            st = '{:0>7d}'.format(st)
            ed = '{:0>7d}'.format(ed)
            key = "{}_{}_{}".format(rec, st, ed)
            feature_writer(key, Y_ss)
            label_writer(key, T_ss.reshape(-1))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--index", type=str)
    parser.add_argument("--num_frames", type=int, default=500)
    parser.add_argument("--context_size", type=int, default=7)
    parser.add_argument("--frame_size", type=int, default=200)
    parser.add_argument("--frame_shift", type=int, default=80)
    parser.add_argument("--subsampling", type=int, default=10)

    args = parser.parse_args()
    convert(args)
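Note: the chunking logic above splits each recording into fixed-size windows; _count_frames gives the number of full windows, and _gen_frame_indices optionally emits a final partial window so the tail is not dropped. A quick standalone check of the same arithmetic:

# Toy check of the windowing: 4500 subsampled frames, 2000-frame chunks.
def _count_frames(data_len, size, step):
    return int((data_len - size + step) / step)

print(_count_frames(4500, 2000, 2000))  # 2 full chunks: [0, 2000) and [2000, 4000)
# With use_last_samples=True, _gen_frame_indices additionally yields (4000, 4500),
# keeping the 500-frame tail as a shorter final chunk.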
egs/callhome/eend_ola/local/gen_feats_scp.py (normal file, 25 lines)
@@ -0,0 +1,25 @@
import os
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_path", type=str)
    parser.add_argument("--out_path", type=str)
    parser.add_argument("--split_num", type=int, default=64)
    args = parser.parse_args()
    root_path = args.root_path
    out_path = args.out_path
    split_num = args.split_num

    with open(os.path.join(out_path, "feats.scp"), "w") as out_f:
        for i in range(split_num):
            idx = str(i + 1)
            feature_file = os.path.join(root_path, "feature.scp.{}".format(idx))
            label_file = os.path.join(root_path, "label.scp.{}".format(idx))
            with open(feature_file) as ff, open(label_file) as fl:
                ff_lines = ff.readlines()
                fl_lines = fl.readlines()
                for ff_line, fl_line in zip(ff_lines, fl_lines):
                    sample_name, f_path = ff_line.strip().split()
                    _, l_path = fl_line.strip().split()
                    out_f.write("{} {} {}\n".format(sample_name, f_path, l_path))
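Note: the merged feats.scp therefore has three columns per line: the chunk key, the feature ark pointer, and the label ark pointer. A reader for this layout might look like the following sketch (the actual consumer lives in FunASR's dataset code):

def read_three_column_scp(path):
    # Each line: "<key> <feature_rxspecifier> <label_rxspecifier>"
    entries = {}
    with open(path) as f:
        for line in f:
            key, feat_path, label_path = line.strip().split()
            entries[key] = (feat_path, label_path)
    return entries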
egs/callhome/eend_ola/local/infer.py (normal file, 138 lines)
@@ -0,0 +1,138 @@
import argparse
import os

import numpy as np
import soundfile as sf
import torch
import yaml
from scipy.signal import medfilt

import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.build_utils.build_model_from_file import build_model_from_file

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        help="model config file",
    )
    parser.add_argument(
        "--model_file",
        type=str,
        help="model path",
    )
    parser.add_argument(
        "--output_rttm_file",
        type=str,
        help="output rttm path",
    )
    parser.add_argument(
        "--wav_scp_file",
        type=str,
        default="wav.scp",
        help="input data path",
    )
    parser.add_argument(
        "--frame_shift",
        type=int,
        default=80,
        help="frame shift",
    )
    parser.add_argument(
        "--frame_size",
        type=int,
        default=200,
        help="frame size",
    )
    parser.add_argument(
        "--context_size",
        type=int,
        default=7,
        help="context size",
    )
    parser.add_argument(
        "--sampling_rate",
        type=int,
        default=8000,
        help="sampling rate",
    )
    parser.add_argument(
        "--subsampling",
        type=int,
        default=10,
        help="subsampling factor",
    )
    parser.add_argument(
        "--shuffle",
        type=bool,  # note: argparse's type=bool treats any non-empty string as True
        default=True,
        help="shuffle speech in time",
    )
    parser.add_argument(
        "--attractor_threshold",
        type=float,
        default=0.5,
        help="threshold for selecting attractors",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    args = parser.parse_args()

    with open(args.config_file) as f:
        configs = yaml.safe_load(f)
    for k, v in configs.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    os.environ['PYTORCH_SEED'] = str(args.seed)

    model, _ = build_model_from_file(config_file=args.config_file, model_file=args.model_file, task_name="diar",
                                     device=args.device)
    model.eval()

    with open(args.wav_scp_file) as f:
        wav_lines = [line.strip().split() for line in f.readlines()]
        wav_items = {x[0]: x[1] for x in wav_lines}

    print("Start inference")
    with open(args.output_rttm_file, "w") as wf:
        for wav_id in wav_items.keys():
            print("Process wav: {}".format(wav_id))
            data, rate = sf.read(wav_items[wav_id])
            speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
            speech = eend_ola_feature.transform(speech)
            speech = eend_ola_feature.splice(speech, context_size=args.context_size)
            speech = speech[::args.subsampling]  # subsampling
            speech = torch.from_numpy(speech)

            with torch.no_grad():
                speech = speech.to(args.device)
                ys, _, _, _ = model.estimate_sequential(
                    [speech],
                    n_speakers=None,
                    th=args.attractor_threshold,
                    shuffle=args.shuffle
                )

            a = ys[0].cpu().numpy()
            a = medfilt(a, (11, 1))
            rst = []
            for spkr_id, frames in enumerate(a.T):
                frames = np.pad(frames, (1, 1), 'constant')
                changes, = np.where(np.diff(frames, axis=0) != 0)
                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
                for s, e in zip(changes[::2], changes[1::2]):
                    st = s * args.frame_shift * args.subsampling / args.sampling_rate
                    dur = (e - s) * args.frame_shift * args.subsampling / args.sampling_rate
                    print(fmt.format(
                        wav_id,
                        st,
                        dur,
                        wav_id + "_" + str(spkr_id)), file=wf)
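Note: the RTTM writer above turns each speaker's binarized frame track into segments by zero-padding both ends and looking at sign changes of the difference; paired change points are segment starts and (exclusive) ends. A standalone sketch of the same idea:

import numpy as np

frames = np.array([0, 1, 1, 1, 0, 0, 1, 1])  # per-frame activity for one speaker
padded = np.pad(frames, (1, 1), 'constant')  # guarantees every segment has a start and an end
changes, = np.where(np.diff(padded) != 0)    # rising and falling edges, alternating
for s, e in zip(changes[::2], changes[1::2]):
    # With frame_shift=80, subsampling=10, and rate=8000, each frame covers 0.1 s.
    print("start={:.1f}s dur={:.1f}s".format(s * 0.1, (e - s) * 0.1))
# Prints: start=0.1s dur=0.3s, then start=0.6s dur=0.2s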
egs/callhome/eend_ola/local/make_callhome.sh (executable file, 73 lines)
@@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2017 David Snyder
# Apache 2.0.
#
# This script prepares the Callhome portion of the NIST SRE 2000
# corpus (LDC2001S97). It is the evaluation dataset used in the
# callhome_diarization recipe.

if [ $# -ne 2 ]; then
  echo "Usage: $0 <callhome-speech> <out-data-dir>"
  echo "e.g.: $0 /mnt/data/LDC2001S97 data/"
  exit 1;
fi

src_dir=$1
data_dir=$2

tmp_dir=$data_dir/callhome/.tmp/
mkdir -p $tmp_dir

# Download some metadata that wasn't provided in the LDC release
if [ ! -d "$tmp_dir/sre2000-key" ]; then
  wget --no-check-certificate -P $tmp_dir/ \
    http://www.openslr.org/resources/10/sre2000-key.tar.gz
  tar -xvf $tmp_dir/sre2000-key.tar.gz -C $tmp_dir/
fi

# The list of 500 recordings
awk '{print $1}' $tmp_dir/sre2000-key/reco2num > $tmp_dir/reco.list

# Create wav.scp file
count=0
missing=0
while read reco; do
  path=$(find $src_dir -name "$reco.sph")
  if [ -z "${path// }" ]; then
    >&2 echo "$0: Missing Sphere file for $reco"
    missing=$((missing+1))
  else
    echo "$reco sph2pipe -f wav -p $path |"
  fi
  count=$((count+1))
done < $tmp_dir/reco.list > $data_dir/callhome/wav.scp

if [ $missing -gt 0 ]; then
  echo "$0: Missing $missing out of $count recordings"
fi

cp $tmp_dir/sre2000-key/segments $data_dir/callhome/
awk '{print $1, $2}' $data_dir/callhome/segments > $data_dir/callhome/utt2spk
utils/utt2spk_to_spk2utt.pl $data_dir/callhome/utt2spk > $data_dir/callhome/spk2utt
cp $tmp_dir/sre2000-key/reco2num $data_dir/callhome/reco2num_spk
cp $tmp_dir/sre2000-key/fullref.rttm $data_dir/callhome/

utils/validate_data_dir.sh --no-text --no-feats $data_dir/callhome
utils/fix_data_dir.sh $data_dir/callhome

utils/copy_data_dir.sh $data_dir/callhome $data_dir/callhome1
utils/copy_data_dir.sh $data_dir/callhome $data_dir/callhome2

utils/shuffle_list.pl $data_dir/callhome/wav.scp | head -n 250 \
  | utils/filter_scp.pl - $data_dir/callhome/wav.scp \
  > $data_dir/callhome1/wav.scp
utils/fix_data_dir.sh $data_dir/callhome1
utils/filter_scp.pl --exclude $data_dir/callhome1/wav.scp \
  $data_dir/callhome/wav.scp > $data_dir/callhome2/wav.scp
utils/fix_data_dir.sh $data_dir/callhome2
utils/filter_scp.pl $data_dir/callhome1/wav.scp $data_dir/callhome/reco2num_spk \
  > $data_dir/callhome1/reco2num_spk
utils/filter_scp.pl $data_dir/callhome2/wav.scp $data_dir/callhome/reco2num_spk \
  > $data_dir/callhome2/reco2num_spk

rm -rf $tmp_dir 2> /dev/null
egs/callhome/eend_ola/local/make_mixture.py (executable file, 120 lines)
@@ -0,0 +1,120 @@
#!/usr/bin/env python3

# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This script generates simulated multi-talker mixtures for diarization.
#
# common/make_mixture.py \
#     mixture.scp \
#     data/mixture \
#     wav/mixture

import argparse
import os
from funasr.modules.eend_ola.utils import kaldi_data
import numpy as np
import math
import soundfile as sf
import json

parser = argparse.ArgumentParser()
parser.add_argument('script',
                    help='list of json')
parser.add_argument('out_data_dir',
                    help='output data dir of mixture')
parser.add_argument('out_wav_dir',
                    help='output mixture wav files are stored here')
parser.add_argument('--rate', type=int, default=16000,
                    help='sampling rate')
args = parser.parse_args()

# open output data files
segments_f = open(args.out_data_dir + '/segments', 'w')
utt2spk_f = open(args.out_data_dir + '/utt2spk', 'w')
wav_scp_f = open(args.out_data_dir + '/wav.scp', 'w')

# "-R" forces the default random seed for reproducibility
resample_cmd = "sox -R -t wav - -t wav - rate {}".format(args.rate)

for line in open(args.script):
    recid, jsonstr = line.strip().split(None, 1)
    indata = json.loads(jsonstr)
    wavfn = indata['recid']
    # recid now includes out_wav_dir
    recid = os.path.join(args.out_wav_dir, wavfn).replace('/', '_')
    noise = indata['noise']
    noise_snr = indata['snr']
    mixture = []
    for speaker in indata['speakers']:
        spkid = speaker['spkid']
        utts = speaker['utts']
        intervals = speaker['intervals']
        rir = speaker['rir']
        data = []
        pos = 0
        for interval, utt in zip(intervals, utts):
            # append silence interval data
            silence = np.zeros(int(interval * args.rate))
            data.append(silence)
            # utterance is reverberated using room impulse response
            preprocess = "wav-reverberate --print-args=false " \
                         " --impulse-response={} - -".format(rir)
            if isinstance(utt, list):
                rec, st, et = utt
                st = np.rint(st * args.rate).astype(int)
                et = np.rint(et * args.rate).astype(int)
            else:
                rec = utt
                st = 0
                et = None
            if rir is not None:
                wav_rxfilename = kaldi_data.process_wav(rec, preprocess)
            else:
                wav_rxfilename = rec
            wav_rxfilename = kaldi_data.process_wav(
                wav_rxfilename, resample_cmd)
            speech, _ = kaldi_data.load_wav(wav_rxfilename, st, et)
            data.append(speech)
            # calculate start/end position in samples
            startpos = pos + len(silence)
            endpos = startpos + len(speech)
            # write segments and utt2spk
            uttid = '{}_{}_{:07d}_{:07d}'.format(
                spkid, recid, int(startpos / args.rate * 100),
                int(endpos / args.rate * 100))
            print(uttid, recid,
                  startpos / args.rate, endpos / args.rate, file=segments_f)
            print(uttid, spkid, file=utt2spk_f)
            # update position for next utterance
            pos = endpos
        data = np.concatenate(data)
        mixture.append(data)

    # pad to the maximum-length speaker track, then mix all speakers
    maxlen = max(len(x) for x in mixture)
    mixture = [np.pad(x, (0, maxlen - len(x)), 'constant') for x in mixture]
    mixture = np.sum(mixture, axis=0)
    # noise is repeated or truncated to fit the mixture length
    noise_resampled = kaldi_data.process_wav(noise, resample_cmd)
    noise_data, _ = kaldi_data.load_wav(noise_resampled)
    if maxlen > len(noise_data):
        noise_data = np.pad(noise_data, (0, maxlen - len(noise_data)), 'wrap')
    else:
        noise_data = noise_data[:maxlen]
    # noise power is scaled according to the selected SNR, then mixed in
    signal_power = np.sum(mixture**2) / len(mixture)
    noise_power = np.sum(noise_data**2) / len(noise_data)
    scale = math.sqrt(
        math.pow(10, - noise_snr / 10) * signal_power / noise_power)
    mixture += noise_data * scale
    # output the wav file and write wav.scp
    outfname = '{}.wav'.format(wavfn)
    outpath = os.path.join(args.out_wav_dir, outfname)
    sf.write(outpath, mixture, args.rate)
    print(recid, os.path.abspath(outpath), file=wav_scp_f)

wav_scp_f.close()
segments_f.close()
utt2spk_f.close()
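Note: the noise scale factor above follows from the decibel definition of SNR, 10*log10(P_signal / P_noise_scaled): scaling the noise by sqrt(10^(-SNR/10) * P_signal / P_noise) yields exactly the requested SNR. A quick numerical check:

import numpy as np

rng = np.random.default_rng(0)
signal = rng.normal(0, 1.0, 16000)
noise = rng.normal(0, 0.3, 16000)
snr_db = 15.0

p_sig = np.mean(signal ** 2)
p_noise = np.mean(noise ** 2)
scale = np.sqrt(10 ** (-snr_db / 10) * p_sig / p_noise)
scaled = noise * scale
# Achieved SNR: prints ~15.0 dB
print(10 * np.log10(p_sig / np.mean(scaled ** 2)))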
egs/callhome/eend_ola/local/make_musan.py (executable file, 123 lines)
@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright 2015 David Snyder
#           2018 Ewald Enzinger
# Apache 2.0.
#
# Modified version of egs/sre16/v1/local/make_musan.py (commit e3fb7c4a0da4167f8c94b80f4d3cc5ab4d0e22e8).
# This version uses the raw MUSAN audio files (16 kHz) and does not use sox to resample at 8 kHz.
#
# This file is meant to be invoked by make_musan.sh.

import os, sys

def process_music_annotations(path):
    utt2spk = {}
    utt2vocals = {}
    lines = open(path, 'r').readlines()
    for line in lines:
        utt, genres, vocals, musician = line.rstrip().split()[:4]
        # For this application, the musician ID isn't important
        utt2spk[utt] = utt
        utt2vocals[utt] = vocals == "Y"
    return utt2spk, utt2vocals

def prepare_music(root_dir, use_vocals):
    utt2vocals = {}
    utt2spk = {}
    utt2wav = {}
    num_good_files = 0
    num_bad_files = 0
    music_dir = os.path.join(root_dir, "music")
    for root, dirs, files in os.walk(music_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".wav"):
                utt = str(file).replace(".wav", "")
                utt2wav[utt] = file_path
            elif str(file) == "ANNOTATIONS":
                utt2spk_part, utt2vocals_part = process_music_annotations(file_path)
                utt2spk.update(utt2spk_part)
                utt2vocals.update(utt2vocals_part)
    utt2spk_str = ""
    utt2wav_str = ""
    for utt in utt2vocals:
        if utt in utt2wav:
            if use_vocals or not utt2vocals[utt]:
                utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n"
                utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n"
                num_good_files += 1
        else:
            print("Missing file {}".format(utt))
            num_bad_files += 1
    print("In music directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files))
    return utt2spk_str, utt2wav_str

def prepare_speech(root_dir):
    utt2spk = {}
    utt2wav = {}
    num_good_files = 0
    num_bad_files = 0
    speech_dir = os.path.join(root_dir, "speech")
    for root, dirs, files in os.walk(speech_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".wav"):
                utt = str(file).replace(".wav", "")
                utt2wav[utt] = file_path
                utt2spk[utt] = utt
    utt2spk_str = ""
    utt2wav_str = ""
    for utt in utt2spk:
        if utt in utt2wav:
            utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n"
            utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n"
            num_good_files += 1
        else:
            print("Missing file {}".format(utt))
            num_bad_files += 1
    print("In speech directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files))
    return utt2spk_str, utt2wav_str

def prepare_noise(root_dir):
    utt2spk = {}
    utt2wav = {}
    num_good_files = 0
    num_bad_files = 0
    noise_dir = os.path.join(root_dir, "noise")
    for root, dirs, files in os.walk(noise_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".wav"):
                utt = str(file).replace(".wav", "")
                utt2wav[utt] = file_path
                utt2spk[utt] = utt
    utt2spk_str = ""
    utt2wav_str = ""
    for utt in utt2spk:
        if utt in utt2wav:
            utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n"
            utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n"
            num_good_files += 1
        else:
            print("Missing file {}".format(utt))
            num_bad_files += 1
    print("In noise directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files))
    return utt2spk_str, utt2wav_str

def main():
    in_dir = sys.argv[1]
    out_dir = sys.argv[2]
    use_vocals = sys.argv[3] == "Y"
    utt2spk_music, utt2wav_music = prepare_music(in_dir, use_vocals)
    utt2spk_speech, utt2wav_speech = prepare_speech(in_dir)
    utt2spk_noise, utt2wav_noise = prepare_noise(in_dir)
    utt2spk = utt2spk_speech + utt2spk_music + utt2spk_noise
    utt2wav = utt2wav_speech + utt2wav_music + utt2wav_noise
    wav_fi = open(os.path.join(out_dir, "wav.scp"), 'w')
    wav_fi.write(utt2wav)
    utt2spk_fi = open(os.path.join(out_dir, "utt2spk"), 'w')
    utt2spk_fi.write(utt2spk)


if __name__ == "__main__":
    main()
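Note: prepare_speech and prepare_noise above are identical except for the subdirectory they walk, so they could be collapsed into a single helper. A possible refactor (hypothetical, not part of the PR; assumes os is imported as in the script):

import os

def prepare_simple(root_dir, subdir):
    # Shared logic of prepare_speech / prepare_noise: every wav file is its own "speaker".
    utt2spk_str, utt2wav_str = "", ""
    for root, dirs, files in os.walk(os.path.join(root_dir, subdir)):
        for file in files:
            if file.endswith(".wav"):
                utt = file.replace(".wav", "")
                utt2spk_str += "{} {}\n".format(utt, utt)
                utt2wav_str += "{} {}\n".format(utt, os.path.join(root, file))
    return utt2spk_str, utt2wav_str

# prepare_speech(d) == prepare_simple(d, "speech"); prepare_noise(d) == prepare_simple(d, "noise")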
egs/callhome/eend_ola/local/make_musan.sh (executable file, 37 lines)
@@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2015 David Snyder
# Apache 2.0.
#
# This script, called by ../run.sh, creates the MUSAN
# data directory. The required dataset is freely available at
# http://www.openslr.org/17/

set -e
in_dir=$1
data_dir=$2
use_vocals='Y'

mkdir -p local/musan.tmp

echo "Preparing ${data_dir}/musan..."
mkdir -p ${data_dir}/musan
local/make_musan.py ${in_dir} ${data_dir}/musan ${use_vocals}

utils/fix_data_dir.sh ${data_dir}/musan

grep "music" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_music
grep "speech" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_speech
grep "noise" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_noise
utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_music \
  ${data_dir}/musan ${data_dir}/musan_music
utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_speech \
  ${data_dir}/musan ${data_dir}/musan_speech
utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_noise \
  ${data_dir}/musan ${data_dir}/musan_noise

utils/fix_data_dir.sh ${data_dir}/musan_music
utils/fix_data_dir.sh ${data_dir}/musan_speech
utils/fix_data_dir.sh ${data_dir}/musan_noise

rm -rf local/musan.tmp
egs/callhome/eend_ola/local/make_sre.pl (executable file, 63 lines)
@@ -0,0 +1,63 @@
#!/usr/bin/perl
#
# Copyright 2015 David Snyder
# Apache 2.0.
# Usage: make_sre.pl <path-to-data> <name-of-source> <sre-ref> <output-dir>

if (@ARGV != 4) {
  print STDERR "Usage: $0 <path-to-data> <name-of-source> <sre-ref> <output-dir>\n";
  print STDERR "e.g. $0 /export/corpora5/LDC/LDC2006S44 sre2004 sre_ref data/sre2004\n";
  exit(1);
}

($db_base, $sre_name, $sre_ref_filename, $out_dir) = @ARGV;
%utt2sph = ();
%spk2gender = ();

$tmp_dir = "$out_dir/tmp";
if (system("mkdir -p $tmp_dir") != 0) {
  die "Error making directory $tmp_dir";
}

if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) {
  die "Error getting list of sph files";
}
open(WAVLIST, "<", "$tmp_dir/sph.list") or die "cannot open wav list";

while(<WAVLIST>) {
  chomp;
  $sph = $_;
  @A1 = split("/", $sph);
  @A2 = split("[./]", $A1[$#A1]);
  $uttId = $A2[0];
  $utt2sph{$uttId} = $sph;
}

open(GNDR, ">", "$out_dir/spk2gender") or die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp";
open(SRE_REF, "<", $sre_ref_filename) or die "Cannot open SRE reference.";
while (<SRE_REF>) {
  chomp;
  ($speaker, $gender, $other_sre_name, $utt_id, $channel) = split(" ", $_);
  $channel_num = "1";
  if ($channel eq "A") {
    $channel_num = "1";
  } else {
    $channel_num = "2";
  }
  if (($other_sre_name eq $sre_name) and (exists $utt2sph{$utt_id})) {
    $full_utt_id = "$speaker-$gender-$sre_name-$utt_id-$channel";
    $spk2gender{"$speaker-$gender"} = $gender;
    print WAV "$full_utt_id"," sph2pipe -f wav -p -c $channel_num $utt2sph{$utt_id} |\n";
    print SPKR "$full_utt_id $speaker-$gender","\n";
  }
}
foreach $speaker (keys %spk2gender) {
  print GNDR "$speaker $spk2gender{$speaker}\n";
}

close(GNDR) || die;
close(SPKR) || die;
close(WAV) || die;
close(SRE_REF) || die;
egs/callhome/eend_ola/local/make_sre.sh (executable file, 48 lines)
@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2015 David Snyder
# Apache 2.0.
#
# See README.txt for more info on data required.

set -e

data_root=$1
data_dir=$2

wget -P data/local/ http://www.openslr.org/resources/15/speaker_list.tgz
tar -C data/local/ -xvf data/local/speaker_list.tgz
sre_ref=data/local/speaker_list

local/make_sre.pl $data_root/LDC2006S44/ \
  sre2004 $sre_ref $data_dir/sre2004

local/make_sre.pl $data_root/LDC2011S01 \
  sre2005 $sre_ref $data_dir/sre2005_train

local/make_sre.pl $data_root/LDC2011S04 \
  sre2005 $sre_ref $data_dir/sre2005_test

local/make_sre.pl $data_root/LDC2011S09 \
  sre2006 $sre_ref $data_dir/sre2006_train

local/make_sre.pl $data_root/LDC2011S10 \
  sre2006 $sre_ref $data_dir/sre2006_test_1

local/make_sre.pl $data_root/LDC2012S01 \
  sre2006 $sre_ref $data_dir/sre2006_test_2

local/make_sre.pl $data_root/LDC2011S05 \
  sre2008 $sre_ref $data_dir/sre2008_train

local/make_sre.pl $data_root/LDC2011S08 \
  sre2008 $sre_ref $data_dir/sre2008_test

utils/combine_data.sh $data_dir/sre \
  $data_dir/sre2004 $data_dir/sre2005_train \
  $data_dir/sre2005_test $data_dir/sre2006_train \
  $data_dir/sre2006_test_1 $data_dir/sre2006_test_2 \
  $data_dir/sre2008_train $data_dir/sre2008_test

utils/validate_data_dir.sh --no-text --no-feats $data_dir/sre
utils/fix_data_dir.sh $data_dir/sre
rm data/local/speaker_list.*
egs/callhome/eend_ola/local/make_swbd2_phase1.pl (executable file, 106 lines)
@@ -0,0 +1,106 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2017 David Snyder
# Apache 2.0

if (@ARGV != 2) {
  print STDERR "Usage: $0 <path-to-LDC98S75> <path-to-output>\n";
  print STDERR "e.g. $0 /export/corpora3/LDC/LDC98S75 data/swbd2_phase1_train\n";
  exit(1);
}
($db_base, $out_dir) = @ARGV;

if (system("mkdir -p $out_dir")) {
  die "Error making directory $out_dir";
}

open(CS, "<$db_base/doc/callstat.tbl") || die "Could not open $db_base/doc/callstat.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";

@badAudio = ("3", "4");

$tmp_dir = "$out_dir/tmp";
if (system("mkdir -p $tmp_dir") != 0) {
  die "Error making directory $tmp_dir";
}

if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) {
  die "Error getting list of sph files";
}

open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list";

%wavs = ();
while(<WAVLIST>) {
  chomp;
  $sph = $_;
  @t = split("/", $sph);
  @t1 = split("[./]", $t[$#t]);
  $uttId = $t1[0];
  $wavs{$uttId} = $sph;
}

while (<CS>) {
  $line = $_;
  @A = split(",", $line);
  @A1 = split("[./]", $A[0]);
  $wav = $A1[0];
  if (/$wav/i ~~ @badAudio) {
    # do nothing
    print "Bad Audio = $wav";
  } else {
    $spkr1 = "sw_" . $A[2];
    $spkr2 = "sw_" . $A[3];
    $gender1 = $A[5];
    $gender2 = $A[6];
    if ($gender1 eq "M") {
      $gender1 = "m";
    } elsif ($gender1 eq "F") {
      $gender1 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if ($gender2 eq "M") {
      $gender2 = "m";
    } elsif ($gender2 eq "F") {
      $gender2 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if (-e "$wavs{$wav}") {
      $uttId = $spkr1 . "_" . $wav . "_1";
      if (!$spk2gender{$spkr1}) {
        $spk2gender{$spkr1} = $gender1;
        print GNDR "$spkr1"," $gender1\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wavs{$wav} |\n";
      print SPKR "$uttId"," $spkr1","\n";

      $uttId = $spkr2 . "_" . $wav . "_2";
      if (!$spk2gender{$spkr2}) {
        $spk2gender{$spkr2} = $gender2;
        print GNDR "$spkr2"," $gender2\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wavs{$wav} |\n";
      print SPKR "$uttId"," $spkr2","\n";
    } else {
      print STDERR "Missing $wavs{$wav} for $wav\n";
    }
  }
}

close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
  die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
  die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
  die "Error validating directory $out_dir";
}
egs/callhome/eend_ola/local/make_swbd2_phase2.pl (executable file, 107 lines)
@@ -0,0 +1,107 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2013 Daniel Povey
# Apache 2.0

if (@ARGV != 2) {
  print STDERR "Usage: $0 <path-to-LDC99S79> <path-to-output>\n";
  print STDERR "e.g. $0 /export/corpora5/LDC/LDC99S79 data/swbd2_phase2_train\n";
  exit(1);
}
($db_base, $out_dir) = @ARGV;

if (system("mkdir -p $out_dir")) {
  die "Error making directory $out_dir";
}

open(CS, "<$db_base/DISC1/doc/callstat.tbl") || die "Could not open $db_base/DISC1/doc/callstat.tbl";
open(CI, "<$db_base/DISC1/doc/callinfo.tbl") || die "Could not open $db_base/DISC1/doc/callinfo.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";

@badAudio = ("3", "4");

$tmp_dir = "$out_dir/tmp";
if (system("mkdir -p $tmp_dir") != 0) {
  die "Error making directory $tmp_dir";
}

if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) {
  die "Error getting list of sph files";
}

open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list";

while(<WAVLIST>) {
  chomp;
  $sph = $_;
  @t = split("/", $sph);
  @t1 = split("[./]", $t[$#t]);
  $uttId = $t1[0];
  $wav{$uttId} = $sph;
}

while (<CS>) {
  $line = $_;
  $ci = <CI>;
  $ci = <CI>;
  @ci = split(",", $ci);
  $wav = $ci[0];
  @A = split(",", $line);
  if (/$wav/i ~~ @badAudio) {
    # do nothing
  } else {
    $spkr1 = "sw_" . $A[2];
    $spkr2 = "sw_" . $A[3];
    $gender1 = $A[4];
    $gender2 = $A[5];
    if ($gender1 eq "M") {
      $gender1 = "m";
    } elsif ($gender1 eq "F") {
      $gender1 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if ($gender2 eq "M") {
      $gender2 = "m";
    } elsif ($gender2 eq "F") {
      $gender2 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if (-e "$wav{$wav}") {
      $uttId = $spkr1 . "_" . $wav . "_1";
      if (!$spk2gender{$spkr1}) {
        $spk2gender{$spkr1} = $gender1;
        print GNDR "$spkr1"," $gender1\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wav{$wav} |\n";
      print SPKR "$uttId"," $spkr1","\n";

      $uttId = $spkr2 . "_" . $wav . "_2";
      if (!$spk2gender{$spkr2}) {
        $spk2gender{$spkr2} = $gender2;
        print GNDR "$spkr2"," $gender2\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wav{$wav} |\n";
      print SPKR "$uttId"," $spkr2","\n";
    } else {
      print STDERR "Missing $wav{$wav} for $wav\n";
    }
  }
}

close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
  die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
  die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
  die "Error validating directory $out_dir";
}
egs/callhome/eend_ola/local/make_swbd2_phase3.pl (executable file, 102 lines)
@@ -0,0 +1,102 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2013 Daniel Povey
# Apache 2.0

if (@ARGV != 2) {
  print STDERR "Usage: $0 <path-to-LDC2002S06> <path-to-output>\n";
  print STDERR "e.g. $0 /export/corpora5/LDC/LDC2002S06 data/swbd2_phase3_train\n";
  exit(1);
}
($db_base, $out_dir) = @ARGV;

if (system("mkdir -p $out_dir")) {
  die "Error making directory $out_dir";
}

open(CS, "<$db_base/DISC1/docs/callstat.tbl") || die "Could not open $db_base/DISC1/docs/callstat.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";

@badAudio = ("3", "4");

$tmp_dir = "$out_dir/tmp";
if (system("mkdir -p $tmp_dir") != 0) {
  die "Error making directory $tmp_dir";
}

if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) {
  die "Error getting list of sph files";
}

open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list";
while(<WAVLIST>) {
  chomp;
  $sph = $_;
  @t = split("/", $sph);
  @t1 = split("[./]", $t[$#t]);
  $uttId = $t1[0];
  $wav{$uttId} = $sph;
}

while (<CS>) {
  $line = $_;
  @A = split(",", $line);
  $wav = "sw_" . $A[0];
  if (/$wav/i ~~ @badAudio) {
    # do nothing
  } else {
    $spkr1 = "sw_" . $A[3];
    $spkr2 = "sw_" . $A[4];
    $gender1 = $A[5];
    $gender2 = $A[6];
    if ($gender1 eq "M") {
      $gender1 = "m";
    } elsif ($gender1 eq "F") {
      $gender1 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if ($gender2 eq "M") {
      $gender2 = "m";
    } elsif ($gender2 eq "F") {
      $gender2 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if (-e "$wav{$wav}") {
      $uttId = $spkr1 . "_" . $wav . "_1";
      if (!$spk2gender{$spkr1}) {
        $spk2gender{$spkr1} = $gender1;
        print GNDR "$spkr1"," $gender1\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wav{$wav} |\n";
      print SPKR "$uttId"," $spkr1","\n";

      $uttId = $spkr2 . "_" . $wav . "_2";
      if (!$spk2gender{$spkr2}) {
        $spk2gender{$spkr2} = $gender2;
        print GNDR "$spkr2"," $gender2\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wav{$wav} |\n";
      print SPKR "$uttId"," $spkr2","\n";
    } else {
      print STDERR "Missing $wav{$wav} for $wav\n";
    }
  }
}

close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
  die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
  die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
  die "Error validating directory $out_dir";
}
egs/callhome/eend_ola/local/make_swbd_cellular1.pl (normal file, 83 lines)
@@ -0,0 +1,83 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2013 Daniel Povey
# Apache 2.0

if (@ARGV != 2) {
  print STDERR "Usage: $0 <path-to-LDC2001S13> <path-to-output>\n";
  print STDERR "e.g. $0 /export/corpora5/LDC/LDC2001S13 data/swbd_cellular1_train\n";
  exit(1);
}
($db_base, $out_dir) = @ARGV;

if (system("mkdir -p $out_dir")) {
  die "Error making directory $out_dir";
}

open(CS, "<$db_base/doc/swb_callstats.tbl") || die "Could not open $db_base/doc/swb_callstats.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";

@badAudio = ("40019", "45024", "40022");

while (<CS>) {
  $line = $_;
  @A = split(",", $line);
  if (/$A[0]/i ~~ @badAudio) {
    # do nothing
  } else {
    $wav = "sw_" . $A[0];
    $spkr1 = "sw_" . $A[1];
    $spkr2 = "sw_" . $A[2];
    $gender1 = $A[3];
    $gender2 = $A[4];
    if ($A[3] eq "M") {
      $gender1 = "m";
    } elsif ($A[3] eq "F") {
      $gender1 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if ($A[4] eq "M") {
      $gender2 = "m";
    } elsif ($A[4] eq "F") {
      $gender2 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if (-e "$db_base/data/$wav.sph") {
      $uttId = $spkr1 . "-swbdc_" . $wav . "_1";
      if (!$spk2gender{$spkr1}) {
        $spk2gender{$spkr1} = $gender1;
        print GNDR "$spkr1"," $gender1\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/data/$wav.sph |\n";
      print SPKR "$uttId"," $spkr1","\n";

      $uttId = $spkr2 . "-swbdc_" . $wav . "_2";
      if (!$spk2gender{$spkr2}) {
        $spk2gender{$spkr2} = $gender2;
        print GNDR "$spkr2"," $gender2\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/data/$wav.sph |\n";
      print SPKR "$uttId"," $spkr2","\n";
    } else {
      print STDERR "Missing $db_base/data/$wav.sph\n";
    }
  }
}

close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
  die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
  die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
  die "Error validating directory $out_dir";
}
egs/callhome/eend_ola/local/make_swbd_cellular2.pl (executable file, 83 lines)
@@ -0,0 +1,83 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2013 Daniel Povey
# Apache 2.0

if (@ARGV != 2) {
  print STDERR "Usage: $0 <path-to-LDC2004S07> <path-to-output>\n";
  print STDERR "e.g. $0 /export/corpora5/LDC/LDC2004S07 data/swbd_cellular2_train\n";
  exit(1);
}
($db_base, $out_dir) = @ARGV;

if (system("mkdir -p $out_dir")) {
  die "Error making directory $out_dir";
}

open(CS, "<$db_base/docs/swb_callstats.tbl") || die "Could not open $db_base/docs/swb_callstats.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";

@badAudio = ("45024", "40022");

while (<CS>) {
  $line = $_;
  @A = split(",", $line);
  if (/$A[0]/i ~~ @badAudio) {
    # do nothing
  } else {
    $wav = "sw_" . $A[0];
    $spkr1 = "sw_" . $A[1];
    $spkr2 = "sw_" . $A[2];
    $gender1 = $A[3];
    $gender2 = $A[4];
    if ($A[3] eq "M") {
      $gender1 = "m";
    } elsif ($A[3] eq "F") {
      $gender1 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if ($A[4] eq "M") {
      $gender2 = "m";
    } elsif ($A[4] eq "F") {
      $gender2 = "f";
    } else {
      die "Unknown Gender in $line";
    }
    if (-e "$db_base/data/$wav.sph") {
      $uttId = $spkr1 . "-swbdc_" . $wav . "_1";
      if (!$spk2gender{$spkr1}) {
        $spk2gender{$spkr1} = $gender1;
        print GNDR "$spkr1"," $gender1\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/data/$wav.sph |\n";
      print SPKR "$uttId"," $spkr1","\n";

      $uttId = $spkr2 . "-swbdc_" . $wav . "_2";
      if (!$spk2gender{$spkr2}) {
        $spk2gender{$spkr2} = $gender2;
        print GNDR "$spkr2"," $gender2\n";
      }
      print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/data/$wav.sph |\n";
      print SPKR "$uttId"," $spkr2","\n";
    } else {
      print STDERR "Missing $db_base/data/$wav.sph\n";
    }
  }
}

close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
  die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
  die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
  die "Error validating directory $out_dir";
}
egs/callhome/eend_ola/local/model_averaging.py (executable file, 28 lines)
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

import argparse

import torch


def average_model(input_files, output_file):
    output_model = {}
    for ckpt_path in input_files:
        model_params = torch.load(ckpt_path, map_location="cpu")
        for key, value in model_params.items():
            if key not in output_model:
                output_model[key] = value
            else:
                output_model[key] += value
    for key in output_model.keys():
        output_model[key] /= len(input_files)
    torch.save(output_model, output_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("output_file")
    parser.add_argument("input_files", nargs='+')
    args = parser.parse_args()

    average_model(args.input_files, args.output_file)
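Note: average_model sums the state dicts key by key and divides by the checkpoint count, i.e. standard post-hoc checkpoint averaging. One caveat (a general observation about PyTorch checkpoints, not something this PR addresses): integer buffers such as BatchNorm's num_batches_tracked do not divide cleanly, so a more defensive variant might cast to float first. A sketch:

import torch

def average_model_safe(input_files, output_file):
    # Accumulate in float64 to avoid integer-division surprises and rounding drift.
    # (A real version would also cast each result back to its original dtype.)
    acc, n = {}, len(input_files)
    for ckpt_path in input_files:
        for key, value in torch.load(ckpt_path, map_location="cpu").items():
            v = value.double() if torch.is_tensor(value) else value
            acc[key] = v if key not in acc else acc[key] + v
    avg = {k: v / n for k, v in acc.items()}
    torch.save(avg, output_file)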
egs/callhome/eend_ola/local/parse_options.sh (executable file, 97 lines)
@@ -0,0 +1,97 @@
#!/usr/bin/env bash

# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
#                 Arnab Ghoshal, Karel Vesely

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.


# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).


###
### The --config file options have lower priority than command-line
### options, so we need to import them first...
###

# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
    . $config  # source the config file.
  fi
done


###
### Now we process the command line options
###
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit. Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
      else printf "$help_message\n" 1>&2; fi;
      exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
      exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined -- if so it's
      # an invalid option and we die. Note: $0 evaluates to the name of the
      # enclosing script.
      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
      # is undefined. We then have to wrap this test inside "eval" because
      # foo_bar is itself inside a variable ($name).
      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;

      oldval="`eval echo \\$$name`";
      # Work out whether we seem to be expecting a Boolean argument.
      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value -- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval $name=\"$2\";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;
    *) break;
  esac
done


# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;


true; # so this script returns exit code 0.
egs/callhome/eend_ola/local/random_mixture.py (executable file, 145 lines)
@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
This script generates random multi-talker mixtures for diarization.
|
||||
It generates a scp-like outputs: lines of "[recid] [json]".
|
||||
recid: recording id of mixture
|
||||
serial numbers like mix_0000001, mix_0000002, ...
|
||||
json: mixture configuration formatted in "one-line"
|
||||
The json format is as following:
|
||||
{
|
||||
'speakers':[ # list of speakers
|
||||
{
|
||||
'spkid': 'Name', # speaker id
|
||||
'rir': '/rirdir/rir.wav', # wav_rxfilename of room impulse response
|
||||
'utts': [ # list of wav_rxfilenames of utterances
|
||||
'/wavdir/utt1.wav',
|
||||
'/wavdir/utt2.wav',...],
|
||||
'intervals': [1.2, 3.4, ...] # list of silence durations before utterances
|
||||
}, ... ],
|
||||
'noise': '/noisedir/noise.wav' # wav_rxfilename of background noise
|
||||
'snr': 15.0, # SNR for mixing background noise
|
||||
'recid': 'mix_000001' # recording id of the mixture
|
||||
}
|
||||
|
||||
Usage:
|
||||
common/random_mixture.py \
|
||||
--n_mixtures=10000 \ # number of mixtures
|
||||
data/voxceleb1_train \ # kaldi-style data dir of utterances
|
||||
data/musan_noise_bg \ # background noises
|
||||
data/simu_rirs \ # room impulse responses
|
||||
> mixture.scp # output scp-like file
|
||||
|
||||
The actual data dir and wav files are generated using make_mixture.py:
|
||||
common/make_mixture.py \
|
||||
mixture.scp \ # scp-like file for mixture
|
||||
data/mixture \ # output data dir
|
||||
wav/mixture # output wav dir
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from funasr.modules.eend_ola.utils import kaldi_data
|
||||
import random
|
||||
import numpy as np
|
||||
import json
|
||||
import itertools
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('data_dir',
|
||||
help='data dir of single-speaker recordings')
|
||||
parser.add_argument('noise_dir',
|
||||
help='data dir of background noise recordings')
|
||||
parser.add_argument('rir_dir',
|
||||
help='data dir of room impulse responses')
|
||||
parser.add_argument('--n_mixtures', type=int, default=10,
|
||||
help='number of mixture recordings')
|
||||
parser.add_argument('--n_speakers', type=int, default=4,
|
||||
help='number of speakers in a mixture')
|
||||
parser.add_argument('--min_utts', type=int, default=10,
|
||||
help='minimum number of utterances per speaker')
|
||||
parser.add_argument('--max_utts', type=int, default=20,
|
||||
help='maximum number of utterances per speaker')
|
||||
parser.add_argument('--sil_scale', type=float, default=10.0,
|
||||
help='average silence duration before each utterance')
|
||||
parser.add_argument('--noise_snrs', default="10:15:20",
|
||||
help='colon-delimited SNRs for background noises')
|
||||
parser.add_argument('--random_seed', type=int, default=777,
|
||||
help='random seed')
|
||||
parser.add_argument('--speech_rvb_probability', type=float, default=1,
|
||||
help='reverb probability')
|
||||
args = parser.parse_args()
|
||||
|
||||
random.seed(args.random_seed)
|
||||
np.random.seed(args.random_seed)
|
||||
|
||||
# load list of wav files from kaldi-style data dirs
|
||||
wavs = kaldi_data.load_wav_scp(
|
||||
os.path.join(args.data_dir, 'wav.scp'))
|
||||
noises = kaldi_data.load_wav_scp(
|
||||
os.path.join(args.noise_dir, 'wav.scp'))
|
||||
rirs = kaldi_data.load_wav_scp(
|
||||
os.path.join(args.rir_dir, 'wav.scp'))
|
||||
|
||||
# spk2utt is used for counting number of utterances per speaker
|
||||
spk2utt = kaldi_data.load_spk2utt(
|
||||
os.path.join(args.data_dir, 'spk2utt'))
|
||||
|
||||
segments = kaldi_data.load_segments_hash(
|
||||
os.path.join(args.data_dir, 'segments'))
|
||||
|
||||
# choice lists for random sampling
|
||||
all_speakers = list(spk2utt.keys())
|
||||
all_noises = list(noises.keys())
|
||||
all_rirs = list(rirs.keys())
|
||||
noise_snrs = [float(x) for x in args.noise_snrs.split(':')]
|
||||
|
||||
mixtures = []
|
||||
for it in range(args.n_mixtures):
|
||||
# recording ids are mix_0000001, mix_0000002, ...
|
||||
recid = 'mix_{:07d}'.format(it + 1)
|
||||
# randomly select speakers, a background noise and a SNR
|
||||
speakers = random.sample(all_speakers, args.n_speakers)
|
||||
noise = random.choice(all_noises)
|
||||
noise_snr = random.choice(noise_snrs)
|
||||
mixture = {'speakers': []}
|
||||
for speaker in speakers:
|
||||
# randomly select the number of utterances
|
||||
n_utts = np.random.randint(args.min_utts, args.max_utts + 1)
|
||||
# utts = spk2utt[speaker][:n_utts]
|
||||
cycle_utts = itertools.cycle(spk2utt[speaker])
|
||||
# random start utterance
|
||||
roll = np.random.randint(0, len(spk2utt[speaker]))
|
||||
for i in range(roll):
|
||||
next(cycle_utts)
|
||||
utts = [next(cycle_utts) for i in range(n_utts)]
|
||||
# randomly select wait time before appending utterance
|
||||
intervals = np.random.exponential(args.sil_scale, size=n_utts)
|
||||
# randomly select a room impulse response
|
||||
if random.random() < args.speech_rvb_probability:
|
||||
rir = rirs[random.choice(all_rirs)]
|
||||
else:
|
||||
rir = None
|
||||
if segments is not None:
|
||||
utts = [segments[utt] for utt in utts]
|
||||
utts = [(wavs[rec], st, et) for (rec, st, et) in utts]
|
||||
mixture['speakers'].append({
|
||||
'spkid': speaker,
|
||||
'rir': rir,
|
||||
'utts': utts,
|
||||
'intervals': intervals.tolist()
|
||||
})
|
||||
else:
|
||||
mixture['speakers'].append({
|
||||
'spkid': speaker,
|
||||
'rir': rir,
|
||||
'utts': [wavs[utt] for utt in utts],
|
||||
'intervals': intervals.tolist()
|
||||
})
|
||||
mixture['noise'] = noises[noise]
|
||||
mixture['snr'] = noise_snr
|
||||
mixture['recid'] = recid
|
||||
print(recid, json.dumps(mixture))
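Each printed line pairs a recording id with its mixture configuration serialized as one-line JSON, so downstream tools can recover the full structure with a single split on the first space. A minimal sketch of reading the output back, assuming it was redirected to mixture.scp as in the usage notes above:

import json

with open('mixture.scp') as f:
    for line in f:
        recid, conf = line.strip().split(' ', 1)  # "mix_0000001 {...}"
        mixture = json.loads(conf)
        print(recid, len(mixture['speakers']), 'speakers, snr', mixture['snr'])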
|
||||
235
egs/callhome/eend_ola/local/run_prepare_shared_eda.sh
Executable file
@ -0,0 +1,235 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita, Shota Horiguchi)
|
||||
# Licensed under the MIT license.
|
||||
#
|
||||
# This script prepares kaldi-style data sets shared across different experiments
|
||||
# - data/xxxx
|
||||
# callhome, sre, swb2, and swb_cellular datasets
|
||||
# - data/simu_${simu_outputs}
|
||||
# simulation mixtures generated with various options
|
||||
|
||||
stage=0
|
||||
|
||||
# Modify corpus directories
|
||||
# - callhome_dir
|
||||
# CALLHOME (LDC2001S97)
|
||||
# - swb2_phase1_train
|
||||
# Switchboard-2 Phase 1 (LDC98S75)
|
||||
# - data_root
|
||||
# LDC99S79, LDC2002S06, LDC2001S13, LDC2004S07,
|
||||
# LDC2006S44, LDC2011S01, LDC2011S04, LDC2011S09,
|
||||
# LDC2011S10, LDC2012S01, LDC2011S05, LDC2011S08
|
||||
# - musan_root
|
||||
# MUSAN corpus (https://www.openslr.org/17/)
|
||||
callhome_dir=
|
||||
swb2_phase1_train=
|
||||
data_root=
|
||||
musan_root=
|
||||
# Modify simulated data storage area.
|
||||
# This script distributes simulated data under these directories
|
||||
simu_actual_dirs=(
|
||||
./s05/$USER/diarization-data
|
||||
./s08/$USER/diarization-data
|
||||
./s09/$USER/diarization-data
|
||||
)
|
||||
|
||||
# data preparation options
|
||||
max_jobs_run=4
|
||||
sad_num_jobs=30
|
||||
sad_opts="--extra-left-context 79 --extra-right-context 21 --frames-per-chunk 150 --extra-left-context-initial 0 --extra-right-context-final 0 --acwt 0.3"
|
||||
sad_graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0"
|
||||
sad_priors_opts="--sil-scale=0.1"
|
||||
|
||||
# simulation options
|
||||
simu_opts_overlap=yes
|
||||
simu_opts_num_speaker_array=(1 2 3 4)
|
||||
simu_opts_sil_scale_array=(2 2 5 9)
|
||||
simu_opts_rvb_prob=0.5
|
||||
simu_opts_num_train=100000
|
||||
simu_opts_min_utts=10
|
||||
simu_opts_max_utts=20
|
||||
|
||||
simu_cmd="run.pl"
|
||||
train_cmd="run.pl"
|
||||
random_mixture_cmd="run.pl"
|
||||
make_mixture_cmd="run.pl"
|
||||
|
||||
. parse_options.sh || exit
|
||||
|
||||
if [ $stage -le 0 ]; then
|
||||
echo "prepare kaldi-style datasets"
|
||||
# Prepare CALLHOME dataset. This will be used for evaluation.
|
||||
if ! validate_data_dir.sh --no-text --no-feats data/callhome1_spkall \
|
||||
|| ! validate_data_dir.sh --no-text --no-feats data/callhome2_spkall; then
|
||||
# imported from https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v1
|
||||
local/make_callhome.sh $callhome_dir data
|
||||
# Generate two-speaker subsets
|
||||
for dset in callhome1 callhome2; do
|
||||
# Extract two-speaker recordings in wav.scp
|
||||
copy_data_dir.sh data/${dset} data/${dset}_spkall
|
||||
# Regenerate segments file from fullref.rttm
|
||||
# $2: recid, $4: start_time, $5: duration, $8: speakerid
|
||||
awk '{printf "%s_%s_%07d_%07d %s %.2f %.2f\n", \
|
||||
$2, $8, $4*100, ($4+$5)*100, $2, $4, $4+$5}' \
|
||||
data/callhome/fullref.rttm | sort > data/${dset}_spkall/segments
|
||||
utils/fix_data_dir.sh data/${dset}_spkall
|
||||
# Speaker ID is '[recid]_[speakerid]'
|
||||
awk '{split($1,A,"_"); printf "%s %s_%s\n", $1, A[1], A[2]}' \
|
||||
data/${dset}_spkall/segments > data/${dset}_spkall/utt2spk
|
||||
utils/fix_data_dir.sh data/${dset}_spkall
|
||||
# Generate rttm files for scoring
|
||||
steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \
|
||||
data/${dset}_spkall/utt2spk data/${dset}_spkall/segments \
|
||||
data/${dset}_spkall/rttm
|
||||
utils/data/get_reco2dur.sh data/${dset}_spkall
|
||||
done
|
||||
fi
|
||||
# Prepare a collection of NIST SRE and SWB data. This will be used for training.
|
||||
if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_comb; then
|
||||
local/make_sre.sh $data_root data
|
||||
# Prepare SWB for x-vector DNN training.
|
||||
local/make_swbd2_phase1.pl $swb2_phase1_train \
|
||||
data/swbd2_phase1_train
|
||||
local/make_swbd2_phase2.pl $data_root/LDC99S79 \
|
||||
data/swbd2_phase2_train
|
||||
local/make_swbd2_phase3.pl $data_root/LDC2002S06 \
|
||||
data/swbd2_phase3_train
|
||||
local/make_swbd_cellular1.pl $data_root/LDC2001S13 \
|
||||
data/swbd_cellular1_train
|
||||
local/make_swbd_cellular2.pl $data_root/LDC2004S07 \
|
||||
data/swbd_cellular2_train
|
||||
# Combine swb and sre data
|
||||
utils/combine_data.sh data/swb_sre_comb \
|
||||
data/swbd_cellular1_train data/swbd_cellular2_train \
|
||||
data/swbd2_phase1_train \
|
||||
data/swbd2_phase2_train data/swbd2_phase3_train data/sre
|
||||
fi
|
||||
# MUSAN data: only the background-noise subset is used
|
||||
if ! validate_data_dir.sh --no-text --no-feats data/musan_noise_bg; then
|
||||
local/make_musan.sh $musan_root data
|
||||
utils/copy_data_dir.sh data/musan_noise data/musan_noise_bg
|
||||
awk '{if(NR>1) print $1,$1}' $musan_root/noise/free-sound/ANNOTATIONS > data/musan_noise_bg/utt2spk
|
||||
utils/fix_data_dir.sh data/musan_noise_bg
|
||||
fi
|
||||
# simu rirs 8k
|
||||
if ! validate_data_dir.sh --no-text --no-feats data/simu_rirs_8k; then
|
||||
mkdir -p data/simu_rirs_8k
|
||||
# if [ ! -e sim_rir_8k.zip ]; then
|
||||
# wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip
|
||||
# fi
|
||||
unzip sim_rir_8k.zip -d data/sim_rir_8k
|
||||
find $PWD/data/sim_rir_8k -iname "*.wav" \
|
||||
| awk '{n=split($1,A,/[\/\.]/); print A[n-3]"_"A[n-1], $1}' \
|
||||
| sort > data/simu_rirs_8k/wav.scp
|
||||
awk '{print $1, $1}' data/simu_rirs_8k/wav.scp > data/simu_rirs_8k/utt2spk
|
||||
utils/fix_data_dir.sh data/simu_rirs_8k
|
||||
fi
|
||||
# Automatic segmentation using pretrained SAD model
|
||||
# it will take one day using 30 CPU jobs:
|
||||
# make_mfcc: 1 hour, compute_output: 18 hours, decode: 0.5 hours
|
||||
sad_nnet_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a
|
||||
sad_work_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a
|
||||
if ! validate_data_dir.sh --no-text $sad_work_dir/swb_sre_comb_seg; then
|
||||
if [ ! -d exp/segmentation_1a ]; then
|
||||
# wget http://kaldi-asr.org/models/4/0004_tdnn_stats_asr_sad_1a.tar.gz
|
||||
tar zxf 0004_tdnn_stats_asr_sad_1a.tar.gz
|
||||
fi
|
||||
steps/segmentation/detect_speech_activity.sh \
|
||||
--nj $sad_num_jobs \
|
||||
--graph-opts "$sad_graph_opts" \
|
||||
--transform-probs-opts "$sad_priors_opts" $sad_opts \
|
||||
data/swb_sre_comb $sad_nnet_dir mfcc_hires $sad_work_dir \
|
||||
$sad_work_dir/swb_sre_comb || exit 1
|
||||
fi
|
||||
# Extract >1.5 sec segments and split into train/valid sets
|
||||
if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_cv; then
|
||||
copy_data_dir.sh data/swb_sre_comb data/swb_sre_comb_seg
|
||||
awk '$4-$3>1.5{print;}' $sad_work_dir/swb_sre_comb_seg/segments > data/swb_sre_comb_seg/segments
|
||||
cp $sad_work_dir/swb_sre_comb_seg/{utt2spk,spk2utt} data/swb_sre_comb_seg
|
||||
fix_data_dir.sh data/swb_sre_comb_seg
|
||||
utils/subset_data_dir_tr_cv.sh data/swb_sre_comb_seg data/swb_sre_tr data/swb_sre_cv
|
||||
fi
|
||||
fi
|
||||
|
||||
simudir=data/simu
|
||||
if [ $stage -le 1 ]; then
|
||||
echo "simulation of mixture"
|
||||
mkdir -p $simudir/.work
|
||||
random_mixture_cmd=local/random_mixture.py
|
||||
make_mixture_cmd=local/make_mixture.py
|
||||
|
||||
for ((i=0; i<${#simu_opts_sil_scale_array[@]}; ++i)); do
|
||||
simu_opts_num_speaker=${simu_opts_num_speaker_array[i]}
|
||||
simu_opts_sil_scale=${simu_opts_sil_scale_array[i]}
|
||||
for dset in swb_sre_tr swb_sre_cv; do
|
||||
if [ "$dset" == "swb_sre_tr" ]; then
|
||||
n_mixtures=${simu_opts_num_train}
|
||||
else
|
||||
n_mixtures=500
|
||||
fi
|
||||
simuid=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures}
|
||||
# check if you have the simulation
|
||||
if ! validate_data_dir.sh --no-text --no-feats $simudir/data/$simuid; then
|
||||
# random mixture generation
|
||||
$train_cmd $simudir/.work/random_mixture_$simuid.log \
|
||||
$random_mixture_cmd --n_speakers $simu_opts_num_speaker --n_mixtures $n_mixtures \
|
||||
--speech_rvb_probability $simu_opts_rvb_prob \
|
||||
--sil_scale $simu_opts_sil_scale \
|
||||
data/$dset data/musan_noise_bg data/simu_rirs_8k \
|
||||
\> $simudir/.work/mixture_$simuid.scp
|
||||
nj=64
|
||||
mkdir -p $simudir/wav/$simuid
|
||||
# distribute simulated data to $simu_actual_dir
|
||||
split_scps=
|
||||
for n in $(seq $nj); do
|
||||
split_scps="$split_scps $simudir/.work/mixture_$simuid.$n.scp"
|
||||
mkdir -p $simudir/.work/data_$simuid.$n
|
||||
actual=${simu_actual_dirs[($n-1)%${#simu_actual_dirs[@]}]}/$simudir/wav/$simuid/$n
|
||||
mkdir -p $actual
|
||||
ln -nfs $actual $simudir/wav/$simuid/$n
|
||||
done
|
||||
utils/split_scp.pl $simudir/.work/mixture_$simuid.scp $split_scps || exit 1
|
||||
|
||||
$simu_cmd --max-jobs-run 64 JOB=1:$nj $simudir/.work/make_mixture_$simuid.JOB.log \
|
||||
$make_mixture_cmd --rate=8000 \
|
||||
$simudir/.work/mixture_$simuid.JOB.scp \
|
||||
$simudir/.work/data_$simuid.JOB $simudir/wav/$simuid/JOB
|
||||
utils/combine_data.sh $simudir/data/$simuid $simudir/.work/data_$simuid.*
|
||||
steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \
|
||||
$simudir/data/$simuid/utt2spk $simudir/data/$simuid/segments \
|
||||
$simudir/data/$simuid/rttm
|
||||
utils/data/get_reco2dur.sh $simudir/data/$simuid
|
||||
fi
|
||||
simuid_concat=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures}
|
||||
mkdir -p $simudir/data/$simuid_concat
|
||||
for f in `ls -F $simudir/data/$simuid | grep -v "/"`; do
|
||||
cat $simudir/data/$simuid/$f >> $simudir/data/$simuid_concat/$f
|
||||
done
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ]; then
|
||||
# compose eval/callhome2_spkall
|
||||
eval_set=data/eval/callhome2_spkall
|
||||
if ! validate_data_dir.sh --no-text --no-feats $eval_set; then
|
||||
utils/copy_data_dir.sh data/callhome2_spkall $eval_set
|
||||
cp data/callhome2_spkall/rttm $eval_set/rttm
|
||||
awk -v dstdir=wav/eval/callhome2_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome2_spkall/wav.scp > $eval_set/wav.scp
|
||||
mkdir -p wav/eval/callhome2_spkall
|
||||
wav-copy scp:data/callhome2_spkall/wav.scp scp:$eval_set/wav.scp
|
||||
utils/data/get_reco2dur.sh $eval_set
|
||||
fi
|
||||
|
||||
# compose eval/callhome1_spkall
|
||||
adapt_set=data/eval/callhome1_spkall
|
||||
if ! validate_data_dir.sh --no-text --no-feats $adapt_set; then
|
||||
utils/copy_data_dir.sh data/callhome1_spkall $adapt_set
|
||||
cp data/callhome1_spkall/rttm $adapt_set/rttm
|
||||
awk -v dstdir=wav/eval/callhome1_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome1_spkall/wav.scp > $adapt_set/wav.scp
|
||||
mkdir -p wav/eval/callhome1_spkall
|
||||
wav-copy scp:data/callhome1_spkall/wav.scp scp:$adapt_set/wav.scp
|
||||
utils/data/get_reco2dur.sh $adapt_set
|
||||
fi
|
||||
fi
|
||||
117
egs/callhome/eend_ola/local/split.py
Normal file
@ -0,0 +1,117 @@
|
||||
import argparse
|
||||
import os
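# Split the kaldi-style reco2dur, spk2utt, segments and utt2spk files of a data dir
# into per-shard copies (reco2dur.N, spk2utt.N, segments.N, utt2spk.N) matching the
# wav.scp.N pieces that utils/split_scp.pl wrote under <root_path>/.work.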
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('root_path', help='raw data path')
|
||||
args = parser.parse_args()
|
||||
|
||||
root_path = args.root_path
|
||||
work_path = os.path.join(root_path, ".work")
|
||||
scp_files = os.listdir(work_path)
|
||||
|
||||
reco2dur_dict = {}
|
||||
with open(os.path.join(root_path, 'reco2dur')) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
reco2dur_dict[parts[0]] = parts[1]
|
||||
|
||||
spk2utt_dict = {}
|
||||
with open(os.path.join(root_path, 'spk2utt')) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
spk = parts[0]
|
||||
utts = parts[1:]
|
||||
for utt in utts:
|
||||
tmp = utt.split('data')
|
||||
rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
|
||||
if rec in spk2utt_dict.keys():
|
||||
spk2utt_dict[rec].append((spk, utt))
|
||||
else:
|
||||
spk2utt_dict[rec] = []
|
||||
spk2utt_dict[rec].append((spk, utt))
|
||||
|
||||
segment_dict = {}
|
||||
with open(os.path.join(root_path, 'segments')) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
if parts[1] in segment_dict.keys():
|
||||
segment_dict[parts[1]].append((parts[0], parts[2], parts[3]))
|
||||
else:
|
||||
segment_dict[parts[1]] = []
|
||||
segment_dict[parts[1]].append((parts[0], parts[2], parts[3]))
|
||||
|
||||
utt2spk_dict = {}
|
||||
with open(os.path.join(root_path, 'utt2spk')) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
utt = parts[0]
|
||||
tmp = utt.split('data')
|
||||
rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
|
||||
if rec in utt2spk_dict.keys():
|
||||
utt2spk_dict[rec].append((parts[0], parts[1]))
|
||||
else:
|
||||
utt2spk_dict[rec] = []
|
||||
utt2spk_dict[rec].append((parts[0], parts[1]))
|
||||
|
||||
for file in scp_files:
|
||||
scp_file = os.path.join(work_path, file)
|
||||
idx = scp_file.split('.')[-1]
|
||||
reco2dur_file = os.path.join(work_path, 'reco2dur.{}'.format(str(idx)))
|
||||
spk2utt_file = os.path.join(work_path, 'spk2utt.{}'.format(str(idx)))
|
||||
segment_file = os.path.join(work_path, 'segments.{}'.format(str(idx)))
|
||||
utt2spk_file = os.path.join(work_path, 'utt2spk.{}'.format(str(idx)))
|
||||
|
||||
fpp = open(scp_file)
|
||||
scp_lines = fpp.readlines()
|
||||
keys = []
|
||||
for line in scp_lines:
|
||||
name = line.strip().split()[0]
|
||||
keys.append(name)
|
||||
|
||||
with open(reco2dur_file, 'w') as f:
|
||||
lines = []
|
||||
for key in keys:
|
||||
string = key + ' ' + reco2dur_dict[key]
|
||||
lines.append(string + '\n')
|
||||
lines[-1] = lines[-1][:-1]
|
||||
f.writelines(lines)
|
||||
|
||||
with open(spk2utt_file, 'w') as f:
|
||||
lines = []
|
||||
for key in keys:
|
||||
items = spk2utt_dict[key]
|
||||
for item in items:
|
||||
string = item[0]
|
||||
for it in item[1:]:
|
||||
string += ' '
|
||||
string += it
|
||||
lines.append(string + '\n')
|
||||
lines[-1] = lines[-1][:-1]
|
||||
f.writelines(lines)
|
||||
|
||||
with open(segment_file, 'w') as f:
|
||||
lines = []
|
||||
for key in keys:
|
||||
items = segment_dict[key]
|
||||
for item in items:
|
||||
string = item[0] + ' ' + key + ' ' + item[1] + ' ' + item[2]
|
||||
lines.append(string + '\n')
|
||||
lines[-1] = lines[-1][:-1]
|
||||
f.writelines(lines)
|
||||
|
||||
with open(utt2spk_file, 'w') as f:
|
||||
lines = []
|
||||
for key in keys:
|
||||
items = utt2spk_dict[key]
|
||||
for item in items:
|
||||
string = item[0] + ' ' + item[1]
|
||||
lines.append(string + '\n')
|
||||
lines[-1] = lines[-1][:-1]
|
||||
f.writelines(lines)
|
||||
|
||||
fpp.close()
|
||||
13
egs/callhome/eend_ola/path.sh
Executable file
@ -0,0 +1,13 @@
|
||||
export FUNASR_DIR=$PWD/../../..
|
||||
|
||||
# kaldi-related
|
||||
export KALDI_ROOT=
|
||||
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
|
||||
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sph2pipe_v2.5:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH
|
||||
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
|
||||
. $KALDI_ROOT/tools/config/common_path.sh
|
||||
|
||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=../../../:$PYTHONPATH
|
||||
export PATH=$FUNASR_DIR/funasr/bin:$PATH
|
||||
324
egs/callhome/eend_ola/run.sh
Normal file
@ -0,0 +1,324 @@
|
||||
#!/usr/bin/env bash
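# Stage overview (matching the echo messages below):
#   -1: simulate mixture data with kaldi (local/run_prepare_shared_eda.sh)
#    0: dump features and labels for the simulated and CALLHOME data
#    1: training on simulated two-speaker data
#    2: training on simulated all-speaker data
#    3: training on simulated all-speaker data with chunk_size 2000
#    4: training on callhome all-speaker data with chunk_size 2000
#    5: inference on callhome2_spkall and DER scoring with md-eval.pl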
|
||||
|
||||
. ./path.sh || exit 1;
|
||||
|
||||
# machines configuration
|
||||
CUDA_VISIBLE_DEVICES="0"
|
||||
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
count=1
|
||||
|
||||
# general configuration
|
||||
dump_cmd=utils/run.pl
|
||||
nj=64
|
||||
|
||||
# feature configuration
|
||||
data_dir="./data"
|
||||
simu_feats_dir=$data_dir/ark_data/dump/simu_data/data
|
||||
simu_feats_dir_chunk2000=$data_dir/ark_data/dump/simu_data_chunk2000/data
|
||||
callhome_feats_dir_chunk2000=$data_dir/ark_data/dump/callhome_chunk2000/data
|
||||
simu_train_dataset=train
|
||||
simu_valid_dataset=dev
|
||||
callhome_train_dataset=callhome1_spkall
|
||||
callhome_valid_dataset=callhome2_spkall
|
||||
|
||||
# model average
|
||||
simu_average_2spkr_start=91
|
||||
simu_average_2spkr_end=100
|
||||
simu_average_allspkr_start=16
|
||||
simu_average_allspkr_end=25
|
||||
callhome_average_start=91
|
||||
callhome_average_end=100
|
||||
|
||||
exp_dir="."
|
||||
input_size=345
|
||||
stage=1
|
||||
stop_stage=5
|
||||
|
||||
# exp tag
|
||||
tag="exp1"
|
||||
|
||||
. local/parse_options.sh || exit 1;
|
||||
|
||||
# Set bash to 'debug' mode; it will exit on:
|
||||
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
|
||||
set -e
|
||||
set -u
|
||||
set -o pipefail
|
||||
|
||||
simu_2spkr_diar_config=conf/train_diar_eend_ola_simu_2spkr.yaml
|
||||
simu_allspkr_diar_config=conf/train_diar_eend_ola_simu_allspkr.yaml
|
||||
simu_allspkr_chunk2000_diar_config=conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml
|
||||
callhome_diar_config=conf/train_diar_eend_ola_callhome_chunk2000.yaml
|
||||
simu_2spkr_model_dir="baseline_$(basename "${simu_2spkr_diar_config}" .yaml)_${tag}"
|
||||
simu_allspkr_model_dir="baseline_$(basename "${simu_allspkr_diar_config}" .yaml)_${tag}"
|
||||
simu_allspkr_chunk2000_model_dir="baseline_$(basename "${simu_allspkr_chunk2000_diar_config}" .yaml)_${tag}"
|
||||
callhome_model_dir="baseline_$(basename "${callhome_diar_config}" .yaml)_${tag}"
|
||||
|
||||
# simulate mixture data for training and inference
|
||||
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||
echo "stage -1: Simulate mixture data for training and inference"
|
||||
echo "The detail can be found in https://github.com/hitachi-speech/EEND"
|
||||
echo "Before running this step, you should download and compile kaldi and set KALDI_ROOT in this script and path.sh"
|
||||
echo "This stage may take a long time, please waiting..."
|
||||
KALDI_ROOT=
|
||||
ln -s $KALDI_ROOT/egs/wsj/s5/steps steps
|
||||
ln -s $KALDI_ROOT/egs/wsj/s5/utils utils
|
||||
local/run_prepare_shared_eda.sh
|
||||
fi
|
||||
|
||||
# Prepare data for training and inference
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "stage 0: Prepare data for training and inference"
|
||||
simu_opts_num_speaker_array=(1 2 3 4)
|
||||
simu_opts_sil_scale_array=(2 2 5 9)
|
||||
simu_opts_num_train=100000
|
||||
|
||||
# for simulated data of chunk500 and chunk2000
|
||||
for dset in swb_sre_cv swb_sre_tr; do
|
||||
if [ "$dset" == "swb_sre_tr" ]; then
|
||||
n_mixtures=${simu_opts_num_train}
|
||||
dataset=train
|
||||
else
|
||||
n_mixtures=500
|
||||
dataset=dev
|
||||
fi
|
||||
simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures}
|
||||
mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work
|
||||
split_scps=
|
||||
for n in $(seq $nj); do
|
||||
split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.scp.$n"
|
||||
done
|
||||
utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
|
||||
python local/split.py ${data_dir}/simu/data/${simu_data_dir}
|
||||
# for chunk_size=500
|
||||
output_dir=${data_dir}/ark_data/dump/simu_data/$dataset
|
||||
mkdir -p $output_dir/.logs
|
||||
$dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \
|
||||
python local/dump_feature.py \
|
||||
--data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \
|
||||
--output_dir $output_dir \
|
||||
--index JOB
|
||||
mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset
|
||||
cat ${data_dir}/ark_data/dump/simu_data/$dataset/feature.scp.* > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feature.scp
|
||||
cat ${data_dir}/ark_data/dump/simu_data/$dataset/label.scp.* > ${data_dir}/ark_data/dump/simu_data/data/$dataset/label.scp
|
||||
paste -d" " ${data_dir}/ark_data/dump/simu_data/data/$dataset/feature.scp <(cut -f2 -d" " ${data_dir}/ark_data/dump/simu_data/data/$dataset/label.scp) > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp
|
||||
grep "ns2" ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats_2spkr.scp
|
||||
# for chunk_size=2000
|
||||
output_dir=${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset
|
||||
mkdir -p $output_dir/.logs
|
||||
$dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \
|
||||
python local/dump_feature.py \
|
||||
--data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \
|
||||
--output_dir $output_dir \
|
||||
--index JOB \
|
||||
--num_frames 2000
|
||||
mkdir -p ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset
|
||||
cat ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset/feature.scp.* > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feature.scp
|
||||
cat ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset/label.scp.* > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/label.scp
|
||||
paste -d" " ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feature.scp <(cut -f2 -d" " ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/label.scp) > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp
|
||||
done
|
||||
|
||||
# for callhome data
|
||||
for dset in callhome1_spkall callhome2_spkall; do
|
||||
find $data_dir/eval/$dset -maxdepth 1 -type f -exec cp {} {}.1 \;
|
||||
output_dir=${data_dir}/ark_data/dump/callhome_chunk2000/$dset
|
||||
mkdir -p $output_dir
|
||||
python local/dump_feature.py \
|
||||
--data_dir $data_dir/eval/$dset \
|
||||
--output_dir $output_dir \
|
||||
--index 1 \
|
||||
--num_frames 2000
|
||||
mkdir -p ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset
|
||||
paste -d" " ${data_dir}/ark_data/dump/callhome_chunk2000/$dset/feature.scp.1 <(cut -f2 -d" " ${data_dir}/ark_data/dump/callhome_chunk2000/$dset/label.scp.1) > ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset/feats.scp
|
||||
done
|
||||
fi
|
||||
|
||||
# Training on simulated two-speaker data
|
||||
world_size=$gpu_num
|
||||
simu_2spkr_ave_id=avg${simu_average_2spkr_start}-${simu_average_2spkr_end}
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "stage 1: Training on simulated two-speaker data"
|
||||
mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir}
|
||||
mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir}/log
|
||||
INIT_FILE=${exp_dir}/exp/${simu_2spkr_model_dir}/ddp_init
|
||||
if [ -f $INIT_FILE ];then
|
||||
rm -f $INIT_FILE
|
||||
fi
|
||||
init_method=file://$(readlink -f $INIT_FILE)
|
||||
echo "$0: init method is $init_method"
|
||||
for ((i = 0; i < $gpu_num; ++i)); do
|
||||
{
|
||||
rank=$i
|
||||
local_rank=$i
|
||||
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
|
||||
train.py \
|
||||
--task_name diar \
|
||||
--gpu_id $gpu_id \
|
||||
--use_preprocessor false \
|
||||
--input_size $input_size \
|
||||
--data_dir ${simu_feats_dir} \
|
||||
--train_set ${simu_train_dataset} \
|
||||
--valid_set ${simu_valid_dataset} \
|
||||
--data_file_names "feats_2spkr.scp" \
|
||||
--resume true \
|
||||
--output_dir ${exp_dir}/exp/${simu_2spkr_model_dir} \
|
||||
--config $simu_2spkr_diar_config \
|
||||
--ngpu $gpu_num \
|
||||
--num_worker_count $count \
|
||||
--dist_init_method $init_method \
|
||||
--dist_world_size $world_size \
|
||||
--dist_rank $rank \
|
||||
--local_rank $local_rank 1> ${exp_dir}/exp/${simu_2spkr_model_dir}/log/train.log.$i 2>&1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
echo "averaging model parameters into ${exp_dir}/exp/$simu_2spkr_model_dir/$simu_2spkr_ave_id.pb"
|
||||
models=`eval echo ${exp_dir}/exp/${simu_2spkr_model_dir}/{$simu_average_2spkr_start..$simu_average_2spkr_end}epoch.pb`
|
||||
python local/model_averaging.py ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb $models
|
||||
fi
|
||||
|
||||
# Training on simulated all-speaker data
|
||||
world_size=$gpu_num
|
||||
simu_allspkr_ave_id=avg${simu_average_allspkr_start}-${simu_average_allspkr_end}
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "stage 2: Training on simulated all-speaker data"
|
||||
mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir}
|
||||
mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir}/log
|
||||
INIT_FILE=${exp_dir}/exp/${simu_allspkr_model_dir}/ddp_init
|
||||
if [ -f $INIT_FILE ];then
|
||||
rm -f $INIT_FILE
|
||||
fi
|
||||
init_method=file://$(readlink -f $INIT_FILE)
|
||||
echo "$0: init method is $init_method"
|
||||
for ((i = 0; i < $gpu_num; ++i)); do
|
||||
{
|
||||
rank=$i
|
||||
local_rank=$i
|
||||
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
|
||||
train.py \
|
||||
--task_name diar \
|
||||
--gpu_id $gpu_id \
|
||||
--use_preprocessor false \
|
||||
--input_size $input_size \
|
||||
--data_dir ${simu_feats_dir} \
|
||||
--train_set ${simu_train_dataset} \
|
||||
--valid_set ${simu_valid_dataset} \
|
||||
--data_file_names "feats.scp" \
|
||||
--resume true \
|
||||
--init_param ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb \
|
||||
--output_dir ${exp_dir}/exp/${simu_allspkr_model_dir} \
|
||||
--config $simu_allspkr_diar_config \
|
||||
--ngpu $gpu_num \
|
||||
--num_worker_count $count \
|
||||
--dist_init_method $init_method \
|
||||
--dist_world_size $world_size \
|
||||
--dist_rank $rank \
|
||||
--local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_model_dir}/log/train.log.$i 2>&1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
echo "averaging model parameters into ${exp_dir}/exp/$simu_allspkr_model_dir/$simu_allspkr_ave_id.pb"
|
||||
models=`eval echo ${exp_dir}/exp/${simu_allspkr_model_dir}/{$simu_average_allspkr_start..$simu_average_allspkr_end}epoch.pb`
|
||||
python local/model_averaging.py ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb $models
|
||||
fi
|
||||
|
||||
# Training on simulated all-speaker data with chunk_size 2000
|
||||
world_size=$gpu_num
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "stage 3: Training on simulated all-speaker data with chunk_size 2000"
|
||||
mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}
|
||||
mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log
|
||||
INIT_FILE=${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/ddp_init
|
||||
if [ -f $INIT_FILE ];then
|
||||
rm -f $INIT_FILE
|
||||
fi
|
||||
init_method=file://$(readlink -f $INIT_FILE)
|
||||
echo "$0: init method is $init_method"
|
||||
for ((i = 0; i < $gpu_num; ++i)); do
|
||||
{
|
||||
rank=$i
|
||||
local_rank=$i
|
||||
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
|
||||
train.py \
|
||||
--task_name diar \
|
||||
--gpu_id $gpu_id \
|
||||
--use_preprocessor false \
|
||||
--input_size $input_size \
|
||||
--data_dir ${simu_feats_dir_chunk2000} \
|
||||
--train_set ${simu_train_dataset} \
|
||||
--valid_set ${simu_valid_dataset} \
|
||||
--data_file_names "feats.scp" \
|
||||
--resume true \
|
||||
--init_param ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb \
|
||||
--output_dir ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} \
|
||||
--config $simu_allspkr_chunk2000_diar_config \
|
||||
--ngpu $gpu_num \
|
||||
--num_worker_count $count \
|
||||
--dist_init_method $init_method \
|
||||
--dist_world_size $world_size \
|
||||
--dist_rank $rank \
|
||||
--local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log/train.log.$i 2>&1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
fi
|
||||
|
||||
# Training on callhome all-speaker data with chunk_size 2000
|
||||
world_size=$gpu_num
|
||||
callhome_ave_id=avg${callhome_average_start}-${callhome_average_end}
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
echo "stage 4: Training on callhome all-speaker data with chunk_size 2000"
|
||||
mkdir -p ${exp_dir}/exp/${callhome_model_dir}
|
||||
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/log
|
||||
INIT_FILE=${exp_dir}/exp/${callhome_model_dir}/ddp_init
|
||||
if [ -f $INIT_FILE ];then
|
||||
rm -f $INIT_FILE
|
||||
fi
|
||||
init_method=file://$(readlink -f $INIT_FILE)
|
||||
echo "$0: init method is $init_method"
|
||||
for ((i = 0; i < $gpu_num; ++i)); do
|
||||
{
|
||||
rank=$i
|
||||
local_rank=$i
|
||||
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
|
||||
train.py \
|
||||
--task_name diar \
|
||||
--gpu_id $gpu_id \
|
||||
--use_preprocessor false \
|
||||
--input_size $input_size \
|
||||
--data_dir ${callhome_feats_dir_chunk2000} \
|
||||
--train_set ${callhome_train_dataset} \
|
||||
--valid_set ${callhome_valid_dataset} \
|
||||
--data_file_names "feats.scp" \
|
||||
--resume true \
|
||||
--init_param ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/1epoch.pb \
|
||||
--output_dir ${exp_dir}/exp/${callhome_model_dir} \
|
||||
--config $callhome_diar_config \
|
||||
--ngpu $gpu_num \
|
||||
--num_worker_count $count \
|
||||
--dist_init_method $init_method \
|
||||
--dist_world_size $world_size \
|
||||
--dist_rank $rank \
|
||||
--local_rank $local_rank 1> ${exp_dir}/exp/${callhome_model_dir}/log/train.log.$i 2>&1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
echo "averaging model parameters into ${exp_dir}/exp/$callhome_model_dir/$callhome_ave_id.pb"
|
||||
models=`eval echo ${exp_dir}/exp/${callhome_model_dir}/{$callhome_average_start..$callhome_average_end}epoch.pb`
|
||||
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
|
||||
fi
|
||||
|
||||
# inference and compute DER
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
echo "Inference"
|
||||
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
|
||||
CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
|
||||
--config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
|
||||
--model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
|
||||
--output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
|
||||
--wav_scp_file $data_dir/eval/callhome2_spkall/wav.scp \
|
||||
1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
|
||||
md-eval.pl -c 0.25 \
|
||||
-r ${data_dir}/eval/${callhome_valid_dataset}/rttm \
|
||||
-s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
|
||||
fi
|
||||
2739
egs/callhome/sond/sond.yaml
Normal file
File diff suppressed because it is too large
2739
egs/callhome/sond/sond_fbank.yaml
Normal file
File diff suppressed because it is too large
97
egs/callhome/sond/unit_test.py
Normal file
@ -0,0 +1,97 @@
|
||||
from funasr.bin.diar_inference_launch import inference_launch
|
||||
import os
|
||||
|
||||
|
||||
def test_fbank_cpu_infer():
|
||||
diar_config_path = "sond_fbank.yaml"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
|
||||
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
|
||||
]
|
||||
pipeline = inference_launch(
|
||||
mode="sond",
|
||||
diar_train_config=diar_config_path,
|
||||
diar_model_file=diar_model_path,
|
||||
output_dir=output_dir,
|
||||
num_workers=0,
|
||||
log_level="INFO",
|
||||
)
|
||||
results = pipeline(data_path_and_name_and_type)
|
||||
print(results)
|
||||
|
||||
|
||||
def test_fbank_gpu_infer():
|
||||
diar_config_path = "sond_fbank.yaml"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
|
||||
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
|
||||
]
|
||||
pipeline = inference_launch(
|
||||
mode="sond",
|
||||
diar_train_config=diar_config_path,
|
||||
diar_model_file=diar_model_path,
|
||||
output_dir=output_dir,
|
||||
ngpu=1,
|
||||
num_workers=1,
|
||||
log_level="INFO",
|
||||
)
|
||||
results = pipeline(data_path_and_name_and_type)
|
||||
print(results)
|
||||
|
||||
|
||||
def test_wav_gpu_infer():
|
||||
diar_config_path = "config.yaml"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_wav.scp", "speech", "sound"),
|
||||
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
|
||||
]
|
||||
pipeline = inference_launch(
|
||||
mode="sond",
|
||||
diar_train_config=diar_config_path,
|
||||
diar_model_file=diar_model_path,
|
||||
output_dir=output_dir,
|
||||
ngpu=1,
|
||||
num_workers=1,
|
||||
log_level="WARNING",
|
||||
)
|
||||
results = pipeline(data_path_and_name_and_type)
|
||||
print(results)
|
||||
|
||||
|
||||
def test_without_profile_gpu_infer():
|
||||
diar_config_path = "config.yaml"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
raw_inputs = [[
|
||||
"data/unit_test/raw_inputs/record.wav",
|
||||
"data/unit_test/raw_inputs/spk1.wav",
|
||||
"data/unit_test/raw_inputs/spk2.wav",
|
||||
"data/unit_test/raw_inputs/spk3.wav",
|
||||
"data/unit_test/raw_inputs/spk4.wav"
|
||||
]]
|
||||
pipeline = inference_launch(
|
||||
mode="sond_demo",
|
||||
diar_train_config=diar_config_path,
|
||||
diar_model_file=diar_model_path,
|
||||
output_dir=output_dir,
|
||||
ngpu=1,
|
||||
num_workers=1,
|
||||
log_level="WARNING",
|
||||
param_dict={},
|
||||
)
|
||||
results = pipeline(raw_inputs=raw_inputs)
|
||||
print(results)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
||||
test_fbank_cpu_infer()
|
||||
# test_fbank_gpu_infer()
|
||||
# test_wav_gpu_infer()
|
||||
# test_without_profile_gpu_infer()
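These checks are meant to be run directly (python unit_test.py), with the __main__ block pinning CUDA_VISIBLE_DEVICES and enabling one test at a time; the relative config and model paths (sond_fbank.yaml, config.yaml, sond.pb) are presumably expected in the working directory.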
|
||||
@ -86,6 +86,12 @@ def build_args(args, parser, extra_task_params):
|
||||
from funasr.build_utils.build_diar_model import class_choices_list
|
||||
for class_choices in class_choices_list:
|
||||
class_choices.add_arguments(task_parser)
|
||||
task_parser.add_argument(
|
||||
"--input_size",
|
||||
type=int_or_none,
|
||||
default=None,
|
||||
help="The number of input dimension of the feature",
|
||||
)
|
||||
|
||||
elif args.task_name == "sv":
|
||||
from funasr.build_utils.build_sv_model import class_choices_list
|
||||
|
||||
@ -4,8 +4,21 @@ from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFac
|
||||
|
||||
def build_dataloader(args):
|
||||
if args.dataset_type == "small":
|
||||
train_iter_factory = SequenceIterFactory(args, mode="train")
|
||||
valid_iter_factory = SequenceIterFactory(args, mode="valid")
|
||||
if args.task_name == "diar" and args.model == "eend_ola":
|
||||
from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
|
||||
train_iter_factory = EENDOLADataLoader(
|
||||
data_file=args.train_data_path_and_name_and_type[0][0],
|
||||
batch_size=args.dataset_conf["batch_conf"]["batch_size"],
|
||||
num_workers=args.dataset_conf["num_workers"],
|
||||
shuffle=True)
|
||||
valid_iter_factory = EENDOLADataLoader(
|
||||
data_file=args.valid_data_path_and_name_and_type[0][0],
|
||||
batch_size=args.dataset_conf["batch_conf"]["batch_size"],
|
||||
num_workers=0,
|
||||
shuffle=False)
|
||||
else:
|
||||
train_iter_factory = SequenceIterFactory(args, mode="train")
|
||||
valid_iter_factory = SequenceIterFactory(args, mode="valid")
|
||||
elif args.dataset_type == "large":
|
||||
train_iter_factory = LargeDataLoader(args, mode="train")
|
||||
valid_iter_factory = LargeDataLoader(args, mode="valid")
|
||||
|
||||
@ -192,18 +192,22 @@ class_choices_list = [
|
||||
|
||||
def build_diar_model(args):
|
||||
# token_list
|
||||
if isinstance(args.token_list, str):
|
||||
with open(args.token_list, encoding="utf-8") as f:
|
||||
token_list = [line.rstrip() for line in f]
|
||||
if args.token_list is not None:
|
||||
if isinstance(args.token_list, str):
|
||||
with open(args.token_list, encoding="utf-8") as f:
|
||||
token_list = [line.rstrip() for line in f]
|
||||
|
||||
# Overwriting token_list to keep it as "portable".
|
||||
args.token_list = list(token_list)
|
||||
elif isinstance(args.token_list, (tuple, list)):
|
||||
token_list = list(args.token_list)
|
||||
# Overwriting token_list to keep it as "portable".
|
||||
args.token_list = list(token_list)
|
||||
elif isinstance(args.token_list, (tuple, list)):
|
||||
token_list = list(args.token_list)
|
||||
else:
|
||||
raise RuntimeError("token_list must be str or list")
|
||||
vocab_size = len(token_list)
|
||||
logging.info(f"Vocabulary size: {vocab_size}")
|
||||
else:
|
||||
raise RuntimeError("token_list must be str or list")
|
||||
vocab_size = len(token_list)
|
||||
logging.info(f"Vocabulary size: {vocab_size}")
|
||||
token_list = None
|
||||
vocab_size = None
|
||||
|
||||
# frontend
|
||||
if args.input_size is None:
|
||||
@ -212,16 +216,14 @@ def build_diar_model(args):
|
||||
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||
else:
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
args.frontend = None
|
||||
args.frontend_conf = {}
|
||||
frontend = None
|
||||
input_size = args.input_size
|
||||
|
||||
# encoder
|
||||
encoder_class = encoder_choices.get_class(args.encoder)
|
||||
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
|
||||
encoder = encoder_class(**args.encoder_conf)
|
||||
|
||||
if args.model == "sond":
|
||||
# data augmentation for spectrogram
|
||||
@ -294,7 +296,7 @@ def build_diar_model(args):
|
||||
**args.model_conf,
|
||||
)
|
||||
|
||||
elif args.model_name == "eend_ola":
|
||||
elif args.model == "eend_ola":
|
||||
# encoder-decoder attractor
|
||||
encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
|
||||
encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
|
||||
|
||||
@ -57,7 +57,7 @@ class SequenceIterFactory(AbsIterFactory):
|
||||
data_path_and_name_and_type,
|
||||
preprocess=preprocess_fn,
|
||||
dest_sample_rate=dest_sample_rate,
|
||||
speed_perturb=args.speed_perturb if mode=="train" else None,
|
||||
speed_perturb=args.speed_perturb if mode == "train" else None,
|
||||
)
|
||||
|
||||
# sampler
|
||||
@ -84,7 +84,7 @@ class SequenceIterFactory(AbsIterFactory):
|
||||
args.max_update = len(bs_list) * args.max_epoch
|
||||
logging.info("Max update: {}".format(args.max_update))
|
||||
|
||||
if args.distributed and mode=="train":
|
||||
if args.distributed and mode == "train":
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
for batch in batches:
|
||||
|
||||
@ -1,21 +1,20 @@
|
||||
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Dict
|
||||
from typing import Tuple
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
||||
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
|
||||
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
||||
from funasr.modules.eend_ola.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
|
||||
from funasr.modules.eend_ola.utils.power import create_powerlabel
|
||||
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
pass
|
||||
@ -33,12 +32,35 @@ def pad_attractor(att, max_n_speakers):
|
||||
return att
|
||||
|
||||
|
||||
def pad_labels(ts, out_size):
|
||||
for i, t in enumerate(ts):
|
||||
if t.shape[1] < out_size:
|
||||
ts[i] = F.pad(
|
||||
t,
|
||||
(0, out_size - t.shape[1], 0, 0),
|
||||
mode='constant',
|
||||
value=0.
|
||||
)
|
||||
return ts
|
||||
|
||||
|
||||
def pad_results(ys, out_size):
|
||||
ys_padded = []
|
||||
for i, y in enumerate(ys):
|
||||
if y.shape[1] < out_size:
|
||||
ys_padded.append(
|
||||
torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
|
||||
else:
|
||||
ys_padded.append(y)
|
||||
return ys_padded
|
||||
|
||||
|
||||
class DiarEENDOLAModel(FunASRModel):
|
||||
"""EEND-OLA diarization model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frontend: WavFrontendMel23,
|
||||
frontend: Optional[WavFrontendMel23],
|
||||
encoder: EENDOLATransformerEncoder,
|
||||
encoder_decoder_attractor: EncoderDecoderAttractor,
|
||||
n_units: int = 256,
|
||||
@ -47,11 +69,10 @@ class DiarEENDOLAModel(FunASRModel):
|
||||
mapping_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.frontend = frontend
|
||||
self.enc = encoder
|
||||
self.eda = encoder_decoder_attractor
|
||||
self.encoder_decoder_attractor = encoder_decoder_attractor
|
||||
self.attractor_loss_weight = attractor_loss_weight
|
||||
self.max_n_speaker = max_n_speaker
|
||||
if mapping_dict is None:
|
||||
@ -74,7 +95,8 @@ class DiarEENDOLAModel(FunASRModel):
|
||||
def forward_post_net(self, logits, ilens):
|
||||
maxlen = torch.max(ilens).to(torch.int).item()
|
||||
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
|
||||
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
|
||||
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
|
||||
enforce_sorted=False)
|
||||
outputs, (_, _) = self.postnet(logits)
|
||||
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
|
||||
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
|
||||
@ -83,95 +105,45 @@ class DiarEENDOLAModel(FunASRModel):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
speech: List[torch.Tensor],
|
||||
speaker_labels: List[torch.Tensor],
|
||||
orders: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
|
||||
# Check that batch_size is unified
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
batch_size = speech.shape[0]
|
||||
assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels))
|
||||
speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
|
||||
speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
|
||||
batch_size = len(speech)
|
||||
|
||||
# for data-parallel
|
||||
text = text[:, : text_lengths.max()]
|
||||
# Encoder
|
||||
encoder_out = self.forward_encoder(speech, speech_lengths)
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
# Encoder-decoder attractor
|
||||
attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
|
||||
speaker_labels_lengths)
|
||||
speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
|
||||
|
||||
# pit loss
|
||||
pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
|
||||
pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
|
||||
|
||||
# pse loss
|
||||
with torch.no_grad():
|
||||
power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
|
||||
to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
|
||||
pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
|
||||
pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
|
||||
pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
|
||||
pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
|
||||
|
||||
loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
|
||||
|
||||
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||
loss_ctc, cer_ctc = None, None
|
||||
stats = dict()
|
||||
|
||||
# 1. CTC branch
|
||||
if self.ctc_weight != 0.0:
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect CTC branch stats
|
||||
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||
stats["cer_ctc"] = cer_ctc
|
||||
|
||||
# Intermediate CTC (optional)
|
||||
loss_interctc = 0.0
|
||||
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
||||
for layer_idx, intermediate_out in intermediate_outs:
|
||||
# we assume intermediate_out has the same length & padding
|
||||
# as those of encoder_out
|
||||
loss_ic, cer_ic = self._calc_ctc_loss(
|
||||
intermediate_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
loss_interctc = loss_interctc + loss_ic
|
||||
|
||||
# Collect Intermediate CTC stats
|
||||
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
||||
loss_ic.detach() if loss_ic is not None else None
|
||||
)
|
||||
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
||||
|
||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||
|
||||
# calculate whole encoder loss
|
||||
loss_ctc = (
|
||||
1 - self.interctc_weight
|
||||
) * loss_ctc + self.interctc_weight * loss_interctc
|
||||
|
||||
# 2b. Attention decoder branch
|
||||
if self.ctc_weight != 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 3. CTC-Att loss definition
|
||||
if self.ctc_weight == 0.0:
|
||||
loss = loss_att
|
||||
elif self.ctc_weight == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||
stats["acc"] = acc_att
|
||||
stats["cer"] = cer_att
|
||||
stats["wer"] = wer_att
|
||||
stats["pse_loss"] = pse_loss.detach()
|
||||
stats["pit_loss"] = pit_loss.detach()
|
||||
stats["attractor_loss"] = attractor_loss.detach()
|
||||
stats["batch_size"] = batch_size
|
||||
|
||||
# Collect total loss stats
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
@ -182,21 +154,20 @@ class DiarEENDOLAModel(FunASRModel):
|
||||
|
||||
def estimate_sequential(self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
n_speakers: int = None,
|
||||
shuffle: bool = True,
|
||||
threshold: float = 0.5,
|
||||
**kwargs):
|
||||
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
|
||||
speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
|
||||
emb = self.forward_encoder(speech, speech_lengths)
|
||||
if shuffle:
|
||||
orders = [np.arange(e.shape[0]) for e in emb]
|
||||
for order in orders:
|
||||
np.random.shuffle(order)
|
||||
attractors, probs = self.eda.estimate(
|
||||
attractors, probs = self.encoder_decoder_attractor.estimate(
|
||||
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
|
||||
else:
|
||||
attractors, probs = self.eda.estimate(emb)
|
||||
attractors, probs = self.encoder_decoder_attractor.estimate(emb)
|
||||
attractors_active = []
|
||||
for p, att, e in zip(probs, attractors, emb):
|
||||
if n_speakers and n_speakers >= 0:
|
||||
|
||||
57
funasr/modules/eend_ola/eend_ola_dataloader.py
Normal file
@ -0,0 +1,57 @@
|
||||
import logging
|
||||
|
||||
import kaldiio
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def custom_collate(batch):
|
||||
keys, speech, speaker_labels, orders = zip(*batch)
|
||||
speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
|
||||
speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
|
||||
orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
|
||||
batch = dict(speech=speech,
|
||||
speaker_labels=speaker_labels,
|
||||
orders=orders)
|
||||
|
||||
return keys, batch
|
||||
|
||||
|
||||
class EENDOLADataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_file,
|
||||
):
|
||||
self.data_file = data_file
|
||||
with open(data_file) as f:
|
||||
lines = f.readlines()
|
||||
self.samples = [line.strip().split() for line in lines]
|
||||
logging.info("total samples: {}".format(len(self.samples)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
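# Each item is (key, speech, speaker_label, order): speech is a (T, D) feature matrix
# loaded from a kaldi ark, speaker_label is reshaped to (T, n_speakers), and order is
# a shuffled permutation of the T frame indices.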
|
||||
key, speech_path, speaker_label_path = self.samples[idx]
|
||||
speech = kaldiio.load_mat(speech_path)
|
||||
speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
|
||||
|
||||
order = np.arange(speech.shape[0])
|
||||
np.random.shuffle(order)
|
||||
|
||||
return key, speech, speaker_label, order
|
||||
|
||||
|
||||
class EENDOLADataLoader():
|
||||
def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
|
||||
dataset = EENDOLADataset(data_file)
|
||||
self.data_loader = DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=custom_collate,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers)
|
||||
|
||||
def build_iter(self, epoch):
|
||||
return self.data_loader
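For reference, a minimal sketch of how this loader is consumed (the feats.scp path and batch size here are illustrative, not taken from the recipe):

from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader

loader = EENDOLADataLoader(data_file='dump/train/feats.scp', batch_size=8, num_workers=4)
for keys, batch in loader.build_iter(epoch=0):
    speech = batch['speech']            # list of (T_i, D) float32 tensors
    labels = batch['speaker_labels']    # list of (T_i, n_spk_i) float32 tensors
    orders = batch['orders']            # shuffled frame indices, one int64 tensor per sample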
|
||||
@ -91,6 +91,7 @@ class EENDOLATransformerEncoder(nn.Module):
                 dropout_rate: float = 0.1,
                 use_pos_emb: bool = False):
        super(EENDOLATransformerEncoder, self).__init__()
+       self.linear_in = nn.Linear(idim, n_units)
        self.lnorm_in = nn.LayerNorm(n_units)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout_rate)
@ -104,25 +105,10 @@ class EENDOLATransformerEncoder(nn.Module):
            setattr(self, '{}{:d}'.format("ff_", i),
                    PositionwiseFeedForward(n_units, e_units, dropout_rate))
        self.lnorm_out = nn.LayerNorm(n_units)
-       if use_pos_emb:
-           self.pos_enc = torch.nn.Sequential(
-               torch.nn.Linear(idim, n_units),
-               torch.nn.LayerNorm(n_units),
-               torch.nn.Dropout(dropout_rate),
-               torch.nn.ReLU(),
-               PositionalEncoding(n_units, dropout_rate),
-           )
-       else:
-           self.linear_in = nn.Linear(idim, n_units)
-           self.pos_enc = None

    def __call__(self, x, x_mask=None):
        BT_size = x.shape[0] * x.shape[1]
-       if self.pos_enc is not None:
-           e = self.pos_enc(x)
-           e = e.view(BT_size, -1)
-       else:
-           e = self.linear_in(x.reshape(BT_size, -1))
+       e = self.linear_in(x.reshape(BT_size, -1))
        for i in range(self.n_layers):
            e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
            s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
@ -130,4 +116,4 @@ class EENDOLATransformerEncoder(nn.Module):
            e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
            s = getattr(self, '{}{:d}'.format("ff_", i))(e)
            e = e + self.dropout(s)
-       return self.lnorm_out(e)
+       return self.lnorm_out(e)
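For orientation, a small runnable shape check of the flattening the simplified __call__ performs; the sizes below are illustrative, not values prescribed by this diff.

import torch

batch, frames, idim, n_units = 2, 500, 345, 256
x = torch.randn(batch, frames, idim)
linear_in = torch.nn.Linear(idim, n_units)
BT_size = x.shape[0] * x.shape[1]
e = linear_in(x.reshape(BT_size, -1))   # (batch*frames, n_units), as in the encoder body above
print(e.shape)                          # torch.Size([1000, 256])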
286  funasr/modules/eend_ola/utils/feature.py  Normal file
@ -0,0 +1,286 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features

import numpy as np
import librosa


def get_input_dim(
        frame_size,
        context_size,
        transform_type,
):
    if transform_type.startswith('logmel23'):
        frame_size = 23
    elif transform_type.startswith('logmel'):
        frame_size = 40
    else:
        fft_size = 1 << (frame_size - 1).bit_length()
        frame_size = int(fft_size / 2) + 1
    input_dim = (2 * context_size + 1) * frame_size
    return input_dim


def transform(
        Y,
        transform_type=None,
        dtype=np.float32):
    """ Transform STFT feature

    Args:
        Y: STFT
            (n_frames, n_bins)-shaped np.complex array
        transform_type:
            None, "log"
        dtype: output data type
            np.float32 is expected
    Returns:
        Y (numpy.array): transformed feature
    """
    Y = np.abs(Y)
    if not transform_type:
        pass
    elif transform_type == 'log':
        Y = np.log(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 16000
        n_mels = 40
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel23':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel23_mn':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        mean = np.mean(Y, axis=0)
        Y = Y - mean
    elif transform_type == 'logmel23_swn':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        # b = np.ones(300)/300
        # mean = scipy.signal.convolve2d(Y, b[:, None], mode='same')

        # simple 2-means based thresholding for mean calculation
        powers = np.sum(Y, axis=1)
        th = (np.max(powers) + np.min(powers)) / 2.0
        for i in range(10):
            th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2
        mean = np.mean(Y[powers > th, :], axis=0)
        Y = Y - mean
    elif transform_type == 'logmel23_mvn':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        mean = np.mean(Y, axis=0)
        Y = Y - mean
        std = np.maximum(np.std(Y, axis=0), 1e-10)
        Y = Y / std
    else:
        raise ValueError('Unknown transform_type: %s' % transform_type)
    return Y.astype(dtype)


def subsample(Y, T, subsampling=1):
    """ Frame subsampling
    """
    Y_ss = Y[::subsampling]
    T_ss = T[::subsampling]
    return Y_ss, T_ss


def splice(Y, context_size=0):
    """ Frame splicing

    Args:
        Y: feature
            (n_frames, n_featdim)-shaped numpy array
        context_size:
            number of frames concatenated on left-side
            if context_size = 5, 11 frames are concatenated.

    Returns:
        Y_spliced: spliced feature
            (n_frames, n_featdim * (2 * context_size + 1))-shaped
    """
    Y_pad = np.pad(
        Y,
        [(context_size, context_size), (0, 0)],
        'constant')
    Y_spliced = np.lib.stride_tricks.as_strided(
        np.ascontiguousarray(Y_pad),
        (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
        (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
    return Y_spliced


def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    """ Compute STFT features

    Args:
        data: audio signal
            (n_samples,)-shaped np.float32 array
        frame_size: number of samples in a frame (must be a power of two)
        frame_shift: number of samples between frames

    Returns:
        stft: STFT frames
            (n_frames, n_bins)-shaped np.complex64 array
    """
    # round up to nearest power of 2
    fft_size = 1 << (frame_size - 1).bit_length()
    # HACK: The last frame is omitted
    #       as librosa.stft produces such an excessive frame
    if len(data) % frame_shift == 0:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T[:-1]
    else:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T


def _count_frames(data_len, size, shift):
    # HACK: Assuming librosa.stft(..., center=True)
    n_frames = 1 + int(data_len / shift)
    if data_len % shift == 0:
        n_frames = n_frames - 1
    return n_frames


def get_frame_labels(
        kaldi_obj,
        rec,
        start=0,
        end=None,
        frame_size=1024,
        frame_shift=256,
        n_speakers=None):
    """ Get frame-aligned labels of given recording
    Args:
        kaldi_obj (KaldiData)
        rec (str): recording id
        start (int): start frame index
        end (int): end frame index
            None means the last frame of recording
        frame_size (int): number of samples in a frame
        frame_shift (int): number of shift samples
        n_speakers (int): number of speakers
            if None, the value is given from data
    Returns:
        T: label
            (n_frames, n_speakers)-shaped np.int32 array
    """
    filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
    speakers = np.unique(
        [kaldi_obj.utt2spk[seg['utt']] for seg
         in filtered_segments]).tolist()
    if n_speakers is None:
        n_speakers = len(speakers)
    es = end * frame_shift if end is not None else None
    data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es)
    n_frames = _count_frames(len(data), frame_size, frame_shift)
    T = np.zeros((n_frames, n_speakers), dtype=np.int32)
    if end is None:
        end = n_frames

    for seg in filtered_segments:
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
        start_frame = np.rint(
            seg['st'] * rate / frame_shift).astype(int)
        end_frame = np.rint(
            seg['et'] * rate / frame_shift).astype(int)
        rel_start = rel_end = None
        if start <= start_frame and start_frame < end:
            rel_start = start_frame - start
        if start < end_frame and end_frame <= end:
            rel_end = end_frame - start
        if rel_start is not None or rel_end is not None:
            T[rel_start:rel_end, speaker_index] = 1
    return T


def get_labeledSTFT(
        kaldi_obj,
        rec, start, end, frame_size, frame_shift,
        n_speakers=None,
        use_speaker_id=False):
    """ Extracts STFT and corresponding labels

    Extracts STFT and corresponding diarization labels for
    given recording id and start/end times

    Args:
        kaldi_obj (KaldiData)
        rec (str): recording id
        start (int): start frame index
        end (int): end frame index
        frame_size (int): number of samples in a frame
        frame_shift (int): number of shift samples
        n_speakers (int): number of speakers
            if None, the value is given from data
    Returns:
        Y: STFT
            (n_frames, n_bins)-shaped np.complex64 array,
        T: label
            (n_frames, n_speakers)-shaped np.int32 array.
    """
    data, rate = kaldi_obj.load_wav(
        rec, start * frame_shift, end * frame_shift)
    Y = stft(data, frame_size, frame_shift)
    filtered_segments = kaldi_obj.segments[rec]
    # filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
    speakers = np.unique(
        [kaldi_obj.utt2spk[seg['utt']] for seg
         in filtered_segments]).tolist()
    if n_speakers is None:
        n_speakers = len(speakers)
    T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32)

    if use_speaker_id:
        all_speakers = sorted(kaldi_obj.spk2utt.keys())
        S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32)

    for seg in filtered_segments:
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
        if use_speaker_id:
            all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']])
        start_frame = np.rint(
            seg['st'] * rate / frame_shift).astype(int)
        end_frame = np.rint(
            seg['et'] * rate / frame_shift).astype(int)
        rel_start = rel_end = None
        if start <= start_frame and start_frame < end:
            rel_start = start_frame - start
        if start < end_frame and end_frame <= end:
            rel_end = end_frame - start
        if rel_start is not None or rel_end is not None:
            T[rel_start:rel_end, speaker_index] = 1
            if use_speaker_id:
                S[rel_start:rel_end, all_speaker_index] = 1

    if use_speaker_id:
        return Y, T, S
    else:
        return Y, T
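A hedged sketch of how these helpers chain together (not part of the diff): the frame, context, and subsampling values below are illustrative, and transform's positional librosa.filters.mel(sr, n_fft, n_mels) call assumes an older librosa API.

import numpy as np

wav = np.random.randn(8000 * 10).astype(np.float32)    # placeholder: 10 s of 8 kHz audio
Y = stft(wav, frame_size=200, frame_shift=80)          # (n_frames, n_bins) complex STFT
feat = transform(Y, transform_type='logmel23_mn')      # 23-dim mean-normalized log-mel
feat = splice(feat, context_size=7)                    # 23 * (2*7+1) = 345 dims per frame
dummy_labels = np.zeros((feat.shape[0], 2), dtype=np.int32)
feat_ss, labels_ss = subsample(feat, dummy_labels, subsampling=10)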
162  funasr/modules/eend_ola/utils/kaldi_data.py  Normal file
@ -0,0 +1,162 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This library provides utilities for kaldi-style data directory.


from __future__ import print_function
import os
import sys
import numpy as np
import subprocess
import soundfile as sf
import io
from functools import lru_cache


def load_segments(segments_file):
    """ load segments file as array """
    if not os.path.exists(segments_file):
        return None
    return np.loadtxt(
        segments_file,
        dtype=[('utt', 'object'),
               ('rec', 'object'),
               ('st', 'f'),
               ('et', 'f')],
        ndmin=1)


def load_segments_hash(segments_file):
    ret = {}
    if not os.path.exists(segments_file):
        return None
    for line in open(segments_file):
        utt, rec, st, et = line.strip().split()
        ret[utt] = (rec, float(st), float(et))
    return ret


def load_segments_rechash(segments_file):
    ret = {}
    if not os.path.exists(segments_file):
        return None
    for line in open(segments_file):
        utt, rec, st, et = line.strip().split()
        if rec not in ret:
            ret[rec] = []
        ret[rec].append({'utt': utt, 'st': float(st), 'et': float(et)})
    return ret


def load_wav_scp(wav_scp_file):
    """ return dictionary { rec: wav_rxfilename } """
    lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
    return {x[0]: x[1] for x in lines}


@lru_cache(maxsize=1)
def load_wav(wav_rxfilename, start=0, end=None):
    """ This function reads audio file and returns data in numpy.float32 array.
    "lru_cache" holds recently loaded audio so that it can be called
    many times on the same audio file.
    OPTIMIZE: controls lru_cache size for random access,
    considering memory size
    """
    if wav_rxfilename.endswith('|'):
        # input piped command
        p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
                             stdout=subprocess.PIPE)
        data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
                                   dtype='float32')
        # cannot seek
        data = data[start:end]
    elif wav_rxfilename == '-':
        # stdin
        data, samplerate = sf.read(sys.stdin, dtype='float32')
        # cannot seek
        data = data[start:end]
    else:
        # normal wav file
        data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
    return data, samplerate


def load_utt2spk(utt2spk_file):
    """ returns dictionary { uttid: spkid } """
    lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
    return {x[0]: x[1] for x in lines}


def load_spk2utt(spk2utt_file):
    """ returns dictionary { spkid: list of uttids } """
    if not os.path.exists(spk2utt_file):
        return None
    lines = [line.strip().split() for line in open(spk2utt_file)]
    return {x[0]: x[1:] for x in lines}


def load_reco2dur(reco2dur_file):
    """ returns dictionary { recid: duration } """
    if not os.path.exists(reco2dur_file):
        return None
    lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
    return {x[0]: float(x[1]) for x in lines}


def process_wav(wav_rxfilename, process):
    """ This function returns preprocessed wav_rxfilename
    Args:
        wav_rxfilename: input
        process: command which can be connected via pipe,
            use stdin and stdout
    Returns:
        wav_rxfilename: output piped command
    """
    if wav_rxfilename.endswith('|'):
        # input piped command
        return wav_rxfilename + process + "|"
    else:
        # stdin "-" or normal file
        return "cat {} | {} |".format(wav_rxfilename, process)


def extract_segments(wavs, segments=None):
    """ This function returns generator of segmented audio as
        (utterance id, numpy.float32 array)
    TODO?: sampling rate is not converted.
    """
    if segments is not None:
        # segments should be sorted by rec-id
        for seg in segments:
            wav = wavs[seg['rec']]
            data, samplerate = load_wav(wav)
            st_sample = np.rint(seg['st'] * samplerate).astype(int)
            et_sample = np.rint(seg['et'] * samplerate).astype(int)
            yield seg['utt'], data[st_sample:et_sample]
    else:
        # segments file not found,
        # wav.scp is used as segmented audio list
        for rec in wavs:
            data, samplerate = load_wav(wavs[rec])
            yield rec, data


class KaldiData:
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.segments = load_segments_rechash(
            os.path.join(self.data_dir, 'segments'))
        self.utt2spk = load_utt2spk(
            os.path.join(self.data_dir, 'utt2spk'))
        self.wavs = load_wav_scp(
            os.path.join(self.data_dir, 'wav.scp'))
        self.reco2dur = load_reco2dur(
            os.path.join(self.data_dir, 'reco2dur'))
        self.spk2utt = load_spk2utt(
            os.path.join(self.data_dir, 'spk2utt'))

    def load_wav(self, recid, start=0, end=None):
        data, rate = load_wav(
            self.wavs[recid], start, end)
        return data, rate
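A hedged end-to-end sketch combining KaldiData with get_labeledSTFT from the feature module above (not part of the diff): "data/train" is an assumed Kaldi-style directory containing wav.scp, segments, utt2spk, spk2utt and reco2dur, and both modules are assumed to be importable.

kaldi_obj = KaldiData("data/train")
rec = next(iter(kaldi_obj.wavs))                        # pick some recording id
Y, T = get_labeledSTFT(kaldi_obj, rec, start=0, end=500,
                       frame_size=200, frame_shift=80)
print(Y.shape, T.shape)                                 # (n_frames, n_bins), (n_frames, n_speakers)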
@ -1,11 +1,10 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from itertools import permutations
-from torch import nn
+from scipy.optimize import linear_sum_assignment


-def standard_loss(ys, ts, label_delay=0):
+def standard_loss(ys, ts):
     losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
     loss = torch.sum(torch.stack(losses))
     n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
@ -13,55 +12,29 @@ def standard_loss(ys, ts, label_delay=0):
     return loss


-def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):
-    max_n_speakers = ts[0].shape[1]
-    olens = [y.shape[0] for y in ys]
-    ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1)
-    ys_mask = [torch.ones(olen).to(ys.device) for olen in olens]
-    ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1)
-
-    losses = []
-    for shift in range(max_n_speakers):
-        ts_roll = [torch.roll(t, -shift, dims=1) for t in ts]
-        ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1)
-        loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none')
-        if ys_mask is not None:
-            loss = loss * ys_mask
-        loss = torch.sum(loss, dim=1)
-        losses.append(loss)
-    losses = torch.stack(losses, dim=2)
-
-    perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32)
-    perms = torch.from_numpy(perms).to(losses.device)
-    y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device)
-    t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long)
-
-    losses_perm = []
-    for t_ind in t_inds:
-        losses_perm.append(
-            torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1))
-    losses_perm = torch.stack(losses_perm, dim=1)
-
-    def select_perm_indices(num, max_num):
-        perms = list(permutations(range(max_num)))
-        sub_perms = list(permutations(range(num)))
-        return [
-            [x[:num] for x in perms].index(perm)
-            for perm in sub_perms]
-
-    masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf'))
-    for i, t in enumerate(ts):
-        n_speakers = n_speakers_list[i]
-        indices = select_perm_indices(n_speakers, max_n_speakers)
-        masks[i, indices] = 0
-    losses_perm += masks
-
-    min_loss = torch.sum(torch.min(losses_perm, dim=1)[0])
-    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device)
-    min_loss = min_loss / n_frames
-
-    min_indices = torch.argmin(losses_perm, dim=1)
-    labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)]
-    labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)]
-
-    return min_loss, labels_perm
+def fast_batch_pit_n_speaker_loss(ys, ts):
+    with torch.no_grad():
+        bs = len(ys)
+        indices = []
+        for b in range(bs):
+            y = ys[b].transpose(0, 1)
+            t = ts[b].transpose(0, 1)
+            C, _ = t.shape
+            y = y[:, None, :].repeat(1, C, 1)
+            t = t[None, :, :].repeat(C, 1, 1)
+            bce_loss = F.binary_cross_entropy(torch.sigmoid(y), t, reduction="none").mean(-1)
+            C = bce_loss.cpu()
+            indices.append(linear_sum_assignment(C))
+        labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)]
+
+    return labels_perm
+
+
+def cal_power_loss(logits, power_ts):
+    losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in
+              zip(logits, power_ts)]
+    loss = torch.sum(torch.stack(losses))
+    n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to(
+        power_ts[0].device)
+    loss = loss / n_frames
+    return loss
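A toy check of the Hungarian-matching path above (not part of the diff; the shapes are illustrative: two utterances of 100 frames with 3 speakers each).

import torch

ys = [torch.randn(100, 3) for _ in range(2)]                    # raw per-frame speaker logits
ts = [torch.randint(0, 2, (100, 3)).float() for _ in range(2)]  # 0/1 speaker-activity labels
labels_perm = fast_batch_pit_n_speaker_loss(ys, ts)             # labels reordered by linear_sum_assignment
loss = standard_loss(ys, labels_perm)                           # BCE against the matched permutation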
@ -196,12 +196,16 @@ def generate_data_list(args, data_dir, dataset, nj=64):
 
 def prepare_data(args, distributed_option):
     distributed = distributed_option.distributed
+    data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
+    data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
+    file_names = args.data_file_names.split(",")
+    batch_type = args.dataset_conf["batch_conf"]["batch_type"]
     if not distributed or distributed_option.dist_rank == 0:
         if hasattr(args, "filter_input") and args.filter_input:
             filter_wav_text(args.data_dir, args.train_set)
             filter_wav_text(args.data_dir, args.valid_set)
 
-        if args.dataset_type == "small":
+        if args.dataset_type == "small" and batch_type != "unsorted":
             calc_shape(args, args.train_set)
             calc_shape(args, args.valid_set)
 
@ -209,9 +213,6 @@ def prepare_data(args, distributed_option):
         generate_data_list(args, args.data_dir, args.train_set)
         generate_data_list(args, args.data_dir, args.valid_set)
 
-    data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
-    data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
-    file_names = args.data_file_names.split(",")
     print("data_names: {}, data_types: {}, file_names: {}".format(data_names, data_types, file_names))
     assert len(data_names) == len(data_types) == len(file_names)
     if args.dataset_type == "small":