update repo

This commit is contained in:
嘉渊 2023-05-12 14:48:21 +08:00
parent 624cad13ba
commit 81123acf88
13 changed files with 0 additions and 1349 deletions

View File

@ -1,121 +0,0 @@
# SOND diarization model training configuration (FunASR / ESPnet-style).
model: sond
model_conf:
  lsm_weight: 0.0
  length_normalized_loss: true
  max_spk_num: 16

# speech encoder
encoder: ecapa_tdnn
encoder_conf:
  # pass by model, equal to feature dim
  # input_size: 80
  pool_size: 20
  stride: 1

speaker_encoder: conv
speaker_encoder_conf:
  input_units: 256
  num_layers: 3
  num_units: 256
  kernel_size: 1
  dropout_rate: 0.0
  position_encoder: null
  out_units: 256
  out_norm: false
  auxiliary_states: false
  # name prefixes used when converting tensors between TF and torch checkpoints
  tf2torch_tensor_name_prefix_torch: speaker_encoder
  tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder

# context-independent scorer
ci_scorer: dot
ci_scorer_conf: {}

# context-dependent scorer
cd_scorer: san
cd_scorer_conf:
  input_size: 512
  output_size: 512
  out_units: 1
  attention_heads: 4
  linear_units: 1024
  num_blocks: 4
  dropout_rate: 0.0
  positional_dropout_rate: 0.0
  attention_dropout_rate: 0.0
  # use string "null" to remove input layer
  input_layer: "null"
  pos_enc_class: null
  normalize_before: true
  tf2torch_tensor_name_prefix_torch: cd_scorer
  tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer

# post net
decoder: fsmn
decoder_conf:
  in_units: 32
  out_units: 2517
  filter_size: 31
  fsmn_num_layers: 6
  dnn_num_layers: 1
  num_memory_units: 512
  ffn_inner_dim: 512
  dropout_rate: 0.0
  tf2torch_tensor_name_prefix_torch: decoder
  tf2torch_tensor_name_prefix_tf: EAND/post_net

frontend: wav_frontend
frontend_conf:
  fs: 16000
  window: povey
  n_mels: 80
  frame_length: 25
  frame_shift: 10
  filter_length_min: -1
  filter_length_max: -1
  lfr_m: 1
  lfr_n: 1
  dither: 0.0
  snip_edges: false

# minibatch related
batch_type: length
# 16s * 16k * 16 samples
batch_bins: 4096000
num_workers: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 50
val_scheduler_criterion:
  - valid
  - acc
best_model_criterion:
  - - valid
    - der
    - min
  - - valid
    - forward_steps
    - max
keep_nbest_models: 10
optim: adam
optim_conf:
  lr: 0.001
scheduler: warmuplr
scheduler_conf:
  warmup_steps: 10000

# without spec aug
# NOTE: specaug is disabled (null); specaug_conf below is kept for reference only.
specaug: null
specaug_conf:
  apply_time_warp: true
  time_warp_window: 5
  time_warp_mode: bicubic
  apply_freq_mask: true
  freq_mask_width_range:
    - 0
    - 30
  num_freq_mask: 2
  apply_time_mask: true
  time_mask_width_range:
    - 0
    - 40
  num_time_mask: 2
log_interval: 50
# without normalize
# NOTE(review): YAML parses `None` as the string "None", not null; the other
# disabled entries use `null` (see specaug) — confirm the loader accepts both.
normalize: None

View File

@ -1,171 +0,0 @@
#!/usr/bin/env bash
# Training / inference driver for a FunASR ASR recipe (stages 3-4).
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="6,7"
gpu_num=2
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="." #feature output dictionary
exp_dir="."
lang=zh
dumpdir=dump/raw
feats_type=raw
token_type=char
scp=wav.scp
type=kaldi_ark
stage=3
stop_stage=4
# feature configuration
feats_dim=
sample_frequency=16000
nj=32
speed_perturb=
# exp tag
tag="exp1"
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=train
valid_set=dev
test_sets="dev test"
asr_config=conf/train_asr_conformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
# NOTE(review): $[...] is deprecated bash arithmetic; $(( )) is the modern form.
if ${gpu_inference}; then
    inference_nj=$[${ngpu}*${njob}]
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
# Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Training"
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    # file:// rendezvous point for torch DDP; recreate it for each run
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    # One background trainer per GPU; rank == local GPU index.
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
            # NOTE(review): $token_list is never assigned in this script; with
            # `set -u` an unset variable aborts the run. Presumably it is
            # exported by path.sh or set via parse_options — verify.
            asr_train.py \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --input_size $feats_dim \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi
# Testing Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/${dumpdir}/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        # Never spawn more jobs than there are utterances.
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        # NOTE(review): run.pl normally takes the job range as a single token
        # (JOB=1:N); the space in `JOB=1: "${_nj}"` looks like a transcription
        # artifact — confirm against the original script.
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1: "${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}
        # Concatenate per-job outputs, keyed by utterance id.
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        # Text normalization + CER scoring.
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi

View File

@ -1,5 +0,0 @@
# Environment setup: expose the FunASR binaries on PATH for this recipe.
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH

View File

@ -1,45 +0,0 @@
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict
class MyRunner(MultiProcessRunnerV3):
    """Multi-process job that reads kaldi-style matrices and dumps their shapes."""

    def prepare(self, parser: argparse.ArgumentParser):
        # Returns (task_list, shared_data, args) as expected by MultiProcessRunnerV3.
        parser.add_argument("--input_scp", type=str, required=True)
        parser.add_argument("--out_path")
        args = parser.parse_args()
        if not os.path.exists(os.path.dirname(args.out_path)):
            os.makedirs(os.path.dirname(args.out_path))
        task_list = load_scp_as_list(args.input_scp)
        return task_list, None, args

    def post(self, result_list, args):
        # Merge per-worker results into "<uttid> <d1>,<d2>,..." lines.
        fd = open(args.out_path, "wt", encoding="utf-8")
        for results in result_list:
            for uttid, shape in results:
                fd.write("{} {}\n".format(uttid, ",".join(shape)))
        fd.close()
def process(task_args):
    """Worker: load each matrix and report its shape as a list of strings."""
    task_idx, task_list, _, args = task_args
    collected = []
    for uttid, file_path in task_list:
        mat = kaldiio.load_mat(file_path)
        collected.append((uttid, [str(dim) for dim in mat.shape]))
    return collected
if __name__ == '__main__':
    # Fan the scp entries out across worker processes.
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -1,140 +0,0 @@
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict
class MyRunner(MultiProcessRunnerV3):
    """Pair meeting wavs with their RTTMs to build frame-level diarization labels."""

    def prepare(self, parser: argparse.ArgumentParser):
        parser.add_argument("--rttm_list", type=str, required=True)
        parser.add_argument("--wav_scp_list", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--n_spk", type=int, default=8)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        parser.add_argument("--max_overlap", default=0, type=int)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        args = parser.parse_args()
        rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()]
        meeting2rttm = OrderedDict()
        for rttm_path in rttm_list:
            meeting2rttm.update(self.load_rttm(rttm_path))
        wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()]
        meeting_scp = OrderedDict()
        for scp_path in wav_scp_list:
            meeting_scp.update(load_scp_as_dict(scp_path))
        # On mismatch, keep only meetings present in both the wav scp and the rttms.
        if len(meeting_scp) != len(meeting2rttm):
            logging.warning("Number of wav and rttm mismatch {} != {}".format(
                len(meeting_scp), len(meeting2rttm)))
            common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
            logging.warning("Keep {} records.".format(len(common_keys)))
            new_meeting_scp = OrderedDict()
            rm_keys = []
            for key in meeting_scp:
                if key not in common_keys:
                    rm_keys.append(key)
                else:
                    new_meeting_scp[key] = meeting_scp[key]
            logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))
            new_meeting2rttm = OrderedDict()
            rm_keys = []
            for key in meeting2rttm:
                if key not in common_keys:
                    rm_keys.append(key)
                else:
                    new_meeting2rttm[key] = meeting2rttm[key]
            logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
            meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
        return task_list, None, args

    @staticmethod
    def load_rttm(rttm_path):
        # Group raw RTTM lines by meeting id (field 1).
        meeting2rttm = OrderedDict()
        for one_line in open(rttm_path, "rt", encoding="utf-8"):
            mid = one_line.strip().split(" ")[1]
            if mid not in meeting2rttm:
                meeting2rttm[mid] = []
            meeting2rttm[mid].append(one_line.strip())
        return meeting2rttm

    def post(self, results_list, args):
        # Workers write their outputs themselves; nothing to merge.
        pass
def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
                sr=None, frame_shift=0.01):
    """Rasterize speaker turns into a (T, n_spk) binary frame-activity matrix.

    spk_turns: iterable of (meeting_id, start_sec, dur_sec, spk_name).
    length: recording length in samples; sr: sample rate (must be provided).
    frame_shift: frame hop in seconds.
    """
    shift = int(frame_shift * sr)
    n_frames = int((float(length) + (float(shift) / 2)) / shift)
    activity = np.zeros([n_spk, n_frames], dtype=np.float32)
    for _, start, dur, spk in spk_turns:
        row = spk_list.index(spk)
        start_smp, dur_smp = int(start * sr), int(dur * sr)
        first = int((float(start_smp) + (float(shift) / 2)) / shift)
        last = int((float(start_smp + dur_smp) + (float(shift) / 2)) / shift)
        activity[row, first:last] = 1
    if remove_sil:
        # Drop frames where nobody speaks.
        keep = np.nonzero(np.sum(activity, axis=0))[0]
        activity = activity[:, keep]
    if max_overlap > 0:
        # Drop frames with more than max_overlap simultaneous speakers.
        keep = np.nonzero(np.sum(activity, axis=0) <= max_overlap)[0]
        activity = activity[:, keep]
    return activity.T  # (T, N)
def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
                 sr=16000, frame_shift=0.01):
    """Build frame labels for one meeting; returns (labels, spk_list).

    The wav is read only for its length in samples; note the `sr` parameter is
    overwritten by the file's actual sample rate.
    """
    wav, sr = soundfile.read(wav_path)
    wav_len = len(wav)
    spk_turns = []
    spk_list = []
    for one_line in rttms:
        parts = one_line.strip().split(" ")
        # RTTM fields: 1=meeting id, 3=start sec, 4=duration sec, 7=speaker name.
        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
        if spk not in spk_list:
            spk_list.append(spk)
        spk_turns.append((mid, st, dur, spk))
    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
    return labels, spk_list
def process(task_args):
    """Worker: write kaldi ark/scp label matrices plus per-meeting speaker lists."""
    task_idx, task_list, _, args = task_args
    spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)),
                           "wt", encoding="utf-8")
    out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    for mid, wav_path, rttms in task_list:
        # NOTE(review): args.sr is not declared in prepare(); presumably the
        # MultiProcessRunnerV3 base parser adds it — verify.
        meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
                                                args.sr, args.frame_shift)
        label_writer(mid, meeting_labels)
        spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))
    spk_list_writer.close()
    label_writer.close()
    return None

if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -1,115 +0,0 @@
import numpy as np
import os
import argparse
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import soundfile as sf
from tqdm import tqdm
class MyRunner(MultiProcessRunnerV3):
    """Cut single-speaker (non-overlapped) segments out of meeting wavs."""

    def prepare(self, parser):
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("rttm", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        wav_scp = load_scp_as_list(args.wav_scp)
        # Group RTTM lines by meeting id (field 1).
        meeting2rttms = {}
        for one_line in open(args.rttm, "rt"):
            parts = [x for x in one_line.strip().split(" ") if x != ""]
            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            if mid not in meeting2rttms:
                meeting2rttms[mid] = []
            meeting2rttms[mid].append(one_line)
        task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
        return task_list, None, args

    def post(self, result_list, args):
        # result_list holds per-worker [extracted_count, found_count] pairs.
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
# Example RTTM line:
# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
    """Rasterize RTTM lines into a (max_spk_num, length) sample-level 0/1 matrix.

    Returns (labels, spk_list); speaker names are normalized to
    "<meeting>_SNNN" for numeric ids and "<meeting>_<name>" otherwise.
    """
    labels = np.zeros([max_spk_num, length], int)
    speakers = []
    for raw_line in rttms:
        fields = [f for f in raw_line.strip().split(" ") if f != ""]
        mid, start, dur, name = fields[1], float(fields[3]), float(fields[4]), fields[7]
        name = name.replace("spk", "").replace(mid, "").replace("-", "")
        if name.isdigit():
            name = "{}_S{:03d}".format(mid, int(name))
        else:
            name = "{}_{}".format(mid, name)
        if name not in speakers:
            speakers.append(name)
        first, n_samples = int(start * sr), int(dur * sr)
        labels[speakers.index(name), first:first + n_samples] = 1
    return labels, speakers
def get_nonoverlap_turns(multi_label, spk_list):
    """Extract [start, end, spk] runs where exactly one speaker is active.

    multi_label: (n_spk, T) 0/1 matrix; returns sample-index segments.
    """
    segments = []
    single = np.sum(multi_label, axis=0) == 1
    cur_spk, active, seg_st = None, False, 0
    for t in range(len(single)):
        if not active and single[t]:
            seg_st, active = t, True
            cur_spk = spk_list[np.argmax(multi_label[:, t], axis=0)]
        if active:
            if not single[t]:
                # Overlap or silence ends the current run.
                active = False
                segments.append([seg_st, t, cur_spk])
            elif single[t] and cur_spk != spk_list[np.argmax(multi_label[:, t], axis=0)]:
                # Immediate speaker change: close the run and open a new one.
                segments.append([seg_st, t, cur_spk])
                seg_st, active = t, True
                cur_spk = spk_list[np.argmax(multi_label[:, t], axis=0)]
    if active:
        segments.append([seg_st, len(single), cur_spk])
    return segments
def process(task_args):
    """Worker: extract non-overlapped speech segments longer than min_dur."""
    task_id, task_list, _, args = task_args
    spk_count = [0, 0]  # [speakers extracted, speakers found]
    for mid, wav_path, rttms in task_list:
        wav, sr = sf.read(wav_path, dtype="int16")
        # NOTE(review): args.sr / args.no_pbar are not added in prepare();
        # presumably they come from the MultiProcessRunnerV3 base parser — verify.
        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
        multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)
        extracted_spk = []
        count = 1
        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
            if (ed - st) >= args.min_dur * args.sr:
                seg = wav[st: ed]
                # Layout: out_dir/<meeting>/<speaker>/<speaker>_Unnnn.wav
                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
                count += 1
                if spk not in extracted_spk:
                    extracted_spk.append(spk)
        if len(extracted_spk) != len(spk_list):
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
            ))
        spk_count[0] += len(extracted_spk)
        spk_count[1] += len(spk_list)
    return spk_count

if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -1,73 +0,0 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import argparse
class MyRunner(MultiProcessRunnerV3):
    """Build sample-level diarization label matrices for whole meetings."""

    def prepare(self, parser):
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--n_spk", type=int, default=4)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        args = parser.parse_args()
        meeting_scp = load_scp_as_dict(os.path.join(args.dir, "meeting.scp"))
        rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        task_list = [(mid, meeting_scp[mid], rttm_path) for mid, rttm_path in rttm_scp]
        return task_list, None, args

    def post(self, results_list, args):
        # Workers write label files themselves; nothing to merge.
        pass
def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000):
    """Sample-level (length, n_spk) binary labels; optionally drop silent samples.

    spk_turns: iterable of (meeting_id, start_sec, dur_sec, spk_name).
    """
    activity = np.zeros([n_spk, length], dtype=int)
    for _, start, dur, spk in spk_turns:
        first, count = int(start * sr), int(dur * sr)
        activity[spk_list.index(spk), first:first + count] = 1
    if not remove_sil:
        return activity.T
    # Keep only samples where at least one speaker is active.
    voiced = np.nonzero(np.sum(activity, axis=0))[0]
    return activity[:, voiced].T  # (T, N)
def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000):
    """Build sample-level labels for one meeting; length comes from the audio duration."""
    wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
    spk_turns = []
    spk_list = []
    for one_line in open(rttm_path, "rt"):
        parts = one_line.strip().split(" ")
        # RTTM fields: 1=meeting id, 3=start sec, 4=duration sec, 7=speaker id (numeric here).
        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
        spk = "{}_S{:03d}".format(mid, spk)
        if spk not in spk_list:
            spk_list.append(spk)
        spk_turns.append((mid, st, dur, spk))
    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil)
    return labels
def process(task_args):
    """Worker: dump one boolean (T, n_spk) label matrix per meeting."""
    _, task_list, _, args = task_args
    for mid, wav_path, rttm_path in task_list:
        meeting_labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil)
        # NOTE(review): np.save appends ".npy", so this actually writes
        # "<mid>.lbl.npy" — confirm downstream readers expect that name.
        save_path = os.path.join(args.out_dir, "{}.lbl".format(mid))
        np.save(save_path, meeting_labels.astype(bool))
        print(mid)
    return None

if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -1,53 +0,0 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import soundfile as sf
from tqdm import tqdm
import argparse
class MyRunner(MultiProcessRunnerV3):
    """Split long meeting wavs into fixed-length, overlapping chunks."""

    def prepare(self, parser):
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--chunk_dur", type=float, default=16)
        parser.add_argument("--shift_dur", type=float, default=4)
        args = parser.parse_args()
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        wav_scp = load_scp_as_list(args.wav_scp)
        return wav_scp, None, args

    def post(self, results_list, args):
        # Chunks are written directly by workers; nothing to merge.
        pass
def process(task_args):
    """Worker: write overlapping chunks of each wav into out_dir/<mid>/."""
    _, task_list, _, args = task_args
    chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr)
    for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
        if not os.path.exists(os.path.join(args.out_dir, mid)):
            os.makedirs(os.path.join(args.out_dir, mid))
        # NOTE(review): positional librosa.load(path, sr, mono) is removed in
        # librosa >= 0.10 (sr/mono are keyword-only) — verify the pinned version.
        wav = librosa.load(wav_path, args.sr, True)[0] * 32767
        n_chunk = (len(wav) - chunk_len) // shift_len + 1
        # Keep a final partial chunk when the tail does not divide evenly.
        if (len(wav) - chunk_len) % shift_len > 0:
            n_chunk += 1
        for i in range(n_chunk):
            seg = wav[i*shift_len: i*shift_len + chunk_len]
            # File name carries start/end times in centiseconds.
            st = int(float(i*shift_len)/args.sr * 100)
            dur = int(float(len(seg))/args.sr * 100)
            file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur)
            save_path = os.path.join(args.out_dir, mid, file_name)
            sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
    return None

if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -1,57 +0,0 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
class MyRunner(MultiProcessRunnerV3):
    """Derive a kaldi-style segments file from per-meeting RTTMs."""

    def prepare(self, parser):
        parser.add_argument("--rttm_scp", type=str)
        parser.add_argument("--seg_file", type=str)
        args = parser.parse_args()
        if not os.path.exists(os.path.dirname(args.seg_file)):
            os.makedirs(os.path.dirname(args.seg_file))
        task_list = load_scp_as_list(args.rttm_scp)
        return task_list, None, args

    def post(self, results_list, args):
        # Concatenate all workers' segment lines into one file.
        with open(args.seg_file, "wt", encoding="utf-8") as fd:
            for results in results_list:
                fd.writelines(results)
def process(task_args):
    """Merge each meeting's RTTM turns into contiguous speech segments.

    Returns kaldi-style segment lines "<mid>-<st>-<ed> <mid> <st_sec> <ed_sec>";
    times are handled at 10 ms (centisecond) resolution.
    """
    _, task_list, _, args = task_args
    lines = []
    for mid, rttm_path in task_list:
        turns = []
        total = 0
        for raw in open(rttm_path, 'rt', encoding="utf-8"):
            fields = raw.strip().split(" ")
            _, start, dur, spk = fields[1], float(fields[3]), float(fields[4]), fields[7]
            st_cs, ed_cs = int(start * 100), int((start + dur) * 100)
            total = max(total, ed_cs)
            turns.append([mid, st_cs, ed_cs, spk])
        # Mark every voiced centisecond, then scan for contiguous runs.
        voiced = np.zeros((total + 1,), dtype=bool)
        for _, st_cs, ed_cs, _ in turns:
            voiced[st_cs:ed_cs] = True
        seg_st, inside = 0, False
        for t in range(total + 1):
            if not inside and voiced[t]:
                seg_st, inside = t, True
            if inside and not voiced[t]:
                inside = False
                lines.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
                    mid, seg_st, t, mid, float(seg_st) / 100, float(t) / 100
                ))
    return lines
if __name__ == '__main__':
    # Fan the per-meeting RTTM tasks out across worker processes.
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -1,138 +0,0 @@
import soundfile
import kaldiio
from tqdm import tqdm
import json
import os
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import numpy as np
import argparse
import random
# Speakers already warned about for having too little speech (avoids log spam).
short_spk_list = []

def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
    """Average randomly chosen utterance vectors of `spk` until ~total_len frames.

    Vectors are L2-normalized per utterance before concatenation and averaging.
    """
    all_utts = spk2utt[spk]
    idx_list = list(range(len(all_utts)))
    random.shuffle(idx_list)
    count = 0
    utt_list = []
    for i in idx_list:
        utt_id = all_utts[i]
        utt_list.append(utt_id)
        count += int(utt2frames[utt_id])
        if count >= total_len:
            break
    if count < 300 and spk not in short_spk_list:
        print("Speaker {} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
        short_spk_list.append(spk)
    ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list]
    ivc_list = [x/np.linalg.norm(x, axis=-1) for x in ivc_list]
    ivc = np.concatenate(ivc_list, axis=0)
    ivc = np.mean(ivc, axis=0, keepdims=False)
    return ivc
def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
    """Chunk meetings and dump (wav, speaker-profile, frame-label) kaldi archives."""
    out_prefix = args.out
    ivc_dim = 192  # x-vector dimension; zero rows pad missing speakers
    win_len, win_shift = 400, 160  # 25 ms / 10 ms at 16 kHz
    label_weights = 2 ** np.array(list(range(args.n_spk)))  # NOTE(review): unused in this function
    wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
    ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
    label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))
    frames_list = []
    chunk_size = int(args.chunk_size * args.sr)
    chunk_shift = int(args.chunk_shift * args.sr)
    for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
        meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32')
        num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
        meeting_labels = np.load(labels_scp[mid])
        for i in range(num_chunk):
            st, ed = i*chunk_shift, i*chunk_shift+chunk_size
            seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100))
            wav_writer(seg_id, meeting_wav[st: ed])
            # One profile row per speaker in the meeting, zero-padded to n_spk rows.
            xvec_list = []
            for spk in meeting2spk_list[mid]:
                spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)
                xvec_list.append(spk_xvec)
            for _ in range(args.n_spk - len(xvec_list)):
                xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
            xvec = np.row_stack(xvec_list)
            ivc_writer(seg_id, xvec)
            # Pool sample-level labels down to one value per 10 ms frame.
            wav_label = meeting_labels[st:ed, :]
            frame_num = (ed-st) // win_shift
            # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant")
            feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
            # NOTE(review): this inner `i` shadows the chunk index above; harmless
            # because the outer loop reassigns `i`, but worth renaming.
            for i in range(frame_num):
                frame_label = wav_label[i*win_shift: (i+1)*win_shift, :]
                feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32)
            label_writer(seg_id, feat_label)
            frames_list.append((mid, feat_label.shape[0]))
    return frames_list
def calc_spk_list(rttm_path):
    """Collect the ordered, de-duplicated "<mid>_SNNN" speaker ids from one RTTM file."""
    speakers = []
    for raw in open(rttm_path, "rt"):
        fields = raw.strip().split(" ")
        # Start/duration are parsed (and discarded) so malformed lines still fail loudly.
        mid, _, _, spk_num = fields[1], float(fields[3]), float(fields[4]), int(fields[7])
        name = "{}_S{:03d}".format(mid, spk_num)
        if name not in speakers:
            speakers.append(name)
    return speakers
def main():
    """Assemble inputs (wavs, labels, x-vectors) and dump SOND training chunks."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=4)
    parser.add_argument("--use_lfr", default=False, action="store_true")
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=16000)
    parser.add_argument("--chunk_size", type=int, default=16)
    parser.add_argument("--chunk_shift", type=int, default=4)
    args = parser.parse_args()
    if not os.path.exists(os.path.dirname(args.out)):
        os.makedirs(os.path.dirname(args.out))
    meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
    labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
    rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
    utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
    # Utterance length in centiseconds, measured from the audio itself.
    utt2frames = {}
    for uttid, wav_path in utt2wav.items():
        wav, sr = soundfile.read(wav_path, dtype="int16")
        utt2frames[uttid] = int(len(wav) / sr * 100)
    meeting2spk_list = {}
    for mid, rttm_path in rttm_scp:
        meeting2spk_list[mid] = calc_spk_list(rttm_path)
    # Keep only utterances that have an x-vector and more than 0.25 s of audio.
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
            if spk not in spk2utt:
                spk2utt[spk] = []
            spk2utt[spk].append(utt)
    # random.shuffle(feat_scp)
    meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
    total_frames = sum([x[1] for x in meeting_lens])
    print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))

if __name__ == '__main__':
    main()

View File

@ -1,110 +0,0 @@
from __future__ import print_function
import numpy as np
import os
import sys
import argparse
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import librosa
import soundfile as sf
from copy import deepcopy
import json
from tqdm import tqdm
class MyRunner(MultiProcessRunnerV3):
    """Cut single-speaker (non-overlapped) segments out of meeting wavs (rttm scp variant)."""

    def prepare(self, parser):
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("rttm_scp", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        wav_scp = load_scp_as_list(args.wav_scp)
        rttm_scp = load_scp_as_dict(args.rttm_scp)
        task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp]
        return task_list, None, args

    def post(self, result_list, args):
        # result_list holds per-worker [extracted_count, found_count] pairs.
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
# Example RTTM line:
# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4):
    """Read one RTTM file into a (max_spk_num, length) sample-level 0/1 matrix.

    Returns (labels, spk_list); numeric speaker ids become "<meeting>_SNNN".
    """
    activity = np.zeros([max_spk_num, length], int)
    speakers = []
    for raw in open(rttm_path, 'rt'):
        fields = raw.strip().split(" ")
        mid, start, dur, name = fields[1], float(fields[3]), float(fields[4]), fields[7]
        if name.isdigit():
            name = "{}_S{:03d}".format(mid, int(name))
        if name not in speakers:
            speakers.append(name)
        first, count = int(start * sr), int(dur * sr)
        activity[speakers.index(name), first:first + count] = 1
    return activity, speakers
def get_nonoverlap_turns(multi_label, spk_list):
    """Return [start, end, spk] sample-index runs with exactly one active speaker."""
    runs = []
    solo = np.sum(multi_label, axis=0) == 1
    who, open_run, run_st = None, False, 0
    for pos in range(len(solo)):
        if not open_run and solo[pos]:
            run_st, open_run = pos, True
            who = spk_list[np.argmax(multi_label[:, pos], axis=0)]
        if open_run:
            if not solo[pos]:
                # Overlap or silence closes the current run.
                open_run = False
                runs.append([run_st, pos, who])
            elif solo[pos] and who != spk_list[np.argmax(multi_label[:, pos], axis=0)]:
                # Back-to-back speaker change: close and immediately reopen.
                runs.append([run_st, pos, who])
                run_st, open_run = pos, True
                who = spk_list[np.argmax(multi_label[:, pos], axis=0)]
    if open_run:
        runs.append([run_st, len(solo), who])
    return runs
def process(task_args):
    """Worker: extract non-overlapped speech segments longer than min_dur."""
    task_id, task_list, _, args = task_args
    spk_count = [0, 0]  # [speakers extracted, speakers found]
    for mid, wav_path, rttm_path in task_list:
        wav, sr = sf.read(wav_path, dtype="int16")
        # NOTE(review): args.sr is not declared in prepare(); presumably the
        # MultiProcessRunnerV3 base parser adds it — verify.
        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
        multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)
        extracted_spk = []
        count = 1
        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
            if (ed - st) >= args.min_dur * args.sr:
                seg = wav[st: ed]
                # Layout: out_dir/<meeting>/<speaker>/<speaker>_Unnnn.wav
                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
                count += 1
                if spk not in extracted_spk:
                    extracted_spk.append(spk)
        if len(extracted_spk) != len(spk_list):
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
            ))
        spk_count[0] += len(extracted_spk)
        spk_count[1] += len(spk_list)
    return spk_count

if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -1,60 +0,0 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import soundfile as sf
import argparse
class MyRunner(MultiProcessRunnerV3):
    """Remove silence from meeting wavs using a kaldi `segments` file."""

    def prepare(self, parser):
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        args = parser.parse_args()
        meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp"))
        vad_file = open(os.path.join(args.dir, "segments"), encoding="utf-8")
        # meeting id -> [(utt id, start sample, end sample), ...]
        meeting2vad = {}
        for one_line in vad_file:
            uid, mid, st, ed = one_line.strip().split(" ")
            # NOTE(review): args.sr is not declared here; presumably added by the
            # MultiProcessRunnerV3 base parser — verify.
            st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
            if mid not in meeting2vad:
                meeting2vad[mid] = []
            meeting2vad[mid].append((uid, st, ed))
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp]
        return task_list, None, args

    def post(self, results_list, args):
        # Outputs are written per meeting by workers; nothing to merge.
        pass
def process(task_args):
    """Concatenate the VAD-kept regions of each meeting wav into one file.

    For every meeting, loads the wav at ``args.sr``, cuts out the VAD
    segments, writes the concatenation as ``<out_dir>/<mid>.wav`` plus a
    position map ``<out_dir>/<mid>.pos`` whose lines are
    ``uid orig_start orig_end new_start new_end`` (sample indices).

    Returns:
        None (outputs are written to disk).
    """
    _, task_list, _, args = task_args
    for mid, wav_path, vad_list in task_list:
        # sr/mono are keyword-only in librosa >= 0.10; the old positional
        # call `librosa.load(wav_path, args.sr, True)` raises TypeError there.
        # Scale float32 [-1, 1] audio up to int16 range.
        wav = librosa.load(wav_path, sr=args.sr, mono=True)[0] * 32767
        seg_list = []
        pos_map = []
        offset = 0
        for uid, st, ed in vad_list:
            seg_list.append(wav[st: ed])
            pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset + ed - st))
            offset = offset + ed - st
        out = np.concatenate(seg_list, axis=0)
        save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
        sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
        map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
        with open(map_path, "wt", encoding="utf-8") as fd:
            fd.writelines(pos_map)
        print(mid)
    return None
if __name__ == '__main__':
    # Script entry: hand `process` to the multi-process runner.
    runner = MyRunner(process)
    runner.run()

View File

@ -1,261 +0,0 @@
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict
import random
from typing import List, Dict
from copy import deepcopy
import json
# Root-logger config: prefix every record with the machine's short hostname
# so logs from multiple worker hosts can be told apart.
logging.basicConfig(
    level="INFO",
    format=f"[{os.uname()[1].split('.')[0]}]"
    f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
class MyRunner(MultiProcessRunnerV3):
    """Runner that prepares data maps for diarization chunk simulation."""

    def prepare(self, parser: argparse.ArgumentParser):
        """Parse args, load (or restore cached) data maps and build tasks.

        When ``--data_json`` points to an existing file the maps are
        restored from it; otherwise they are loaded from the individual scp
        files and, if ``--data_json`` was given, cached there. The previous
        version crashed with ``TypeError`` in ``os.path.exists(None)`` when
        ``--data_json`` was omitted; ``None`` now simply skips the cache.

        Returns:
            ``(label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args)``
            where ``label_list`` is the per-meeting frame-label scp that is
            split across workers.
        """
        parser.add_argument("--label_scp", type=str, required=True)
        parser.add_argument("--wav_scp", type=str, required=True)
        parser.add_argument("--utt2spk", type=str, required=True)
        parser.add_argument("--spk2meeting", type=str, required=True)
        parser.add_argument("--utt2xvec", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--chunk_size", type=float, default=16)
        parser.add_argument("--chunk_shift", type=float, default=4)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        parser.add_argument("--embedding_dim", type=int, default=None)
        parser.add_argument("--average_emb_num", type=int, default=0)
        parser.add_argument("--subset", type=int, default=0)
        parser.add_argument("--data_json", type=str, default=None)
        parser.add_argument("--seed", type=int, default=1234)
        parser.add_argument("--log_interval", type=int, default=100)
        args = parser.parse_args()
        # Seed both RNGs so chunk simulation is reproducible.
        random.seed(args.seed)
        np.random.seed(args.seed)

        logging.info("Loading data...")
        if args.data_json is None or not os.path.exists(args.data_json):
            label_list = load_scp_as_list(args.label_scp)
            wav_scp = load_scp_as_dict(args.wav_scp)
            utt2spk = load_scp_as_dict(args.utt2spk)
            utt2xvec = load_scp_as_dict(args.utt2xvec)
            spk2meeting = load_scp_as_dict(args.spk2meeting)
            # Invert spk->meeting and utt->spk into grouped maps.
            meeting2spks = OrderedDict()
            for spk, meeting in spk2meeting.items():
                meeting2spks.setdefault(meeting, []).append(spk)
            spk2utts = OrderedDict()
            for utt, spk in utt2spk.items():
                spk2utts.setdefault(spk, []).append(utt)
            if args.data_json is not None:
                # Cache the maps so later runs can skip the scp parsing.
                # Guard dirname: a bare filename has no directory component.
                json_dir = os.path.dirname(args.data_json)
                if json_dir:
                    os.makedirs(json_dir, exist_ok=True)
                logging.info("Dump data...")
                with open(args.data_json, "wt", encoding="utf-8") as fd:
                    json.dump({
                        "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
                        "spk2utts": spk2utts, "meeting2spks": meeting2spks
                    }, fd, ensure_ascii=False, indent=4)
        else:
            with open(args.data_json, "rt", encoding="utf-8") as fd:
                data_dict = json.load(fd)
            label_list = data_dict["label_list"]
            wav_scp = data_dict["wav_scp"]
            utt2xvec = data_dict["utt2xvec"]
            spk2utts = data_dict["spk2utts"]
            meeting2spks = data_dict["meeting2spks"]

        os.makedirs(args.out_dir, exist_ok=True)
        # Convert chunk size/shift from seconds to frame counts.
        args.chunk_size = int(args.chunk_size / args.frame_shift)
        args.chunk_shift = int(args.chunk_shift / args.frame_shift)
        if args.embedding_dim is None:
            # Probe one x-vector matrix; assumes shape (num_vecs, dim) — the
            # same layout calculate_embedding() relies on.
            args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
            logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
        logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
            len(wav_scp), len(spk2utts), len(meeting2spks)
        ))
        return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args

    def post(self, results_list, args):
        """Log the total number of simulated chunks across all workers."""
        logging.info("[main]: Got {} chunks.".format(sum(results_list)))
def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
    """Draw a random waveform chunk of ``sample_length`` samples for ``spk``.

    Randomly chosen utterances of the speaker are concatenated until the
    total length reaches ``sample_length``; a random window of that length
    is then cut from the concatenation.
    """
    pieces, total = [], 0
    while total < sample_length:
        utt_id = random.choice(spk2utts[spk])
        # NOTE(review): the file's sample rate is discarded here — assumes
        # all source wavs share one rate; confirm upstream.
        audio, _fs = soundfile.read(wav_scp[utt_id], dtype='float32')
        pieces.append(audio)
        total += len(audio)
    joined = np.concatenate(pieces, axis=0)
    offset = random.randint(0, len(joined) - sample_length)
    return joined[offset: offset + sample_length]
def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
    """Return an averaged, L2-normalized x-vector for ``spk``.

    The dummy speaker ``"None"`` maps to an all-zero embedding. Otherwise
    ``average_emb_num`` randomly sampled x-vectors (all of them when it is
    0 or exceeds the utterance count) are length-normalized and averaged.
    """
    if spk == "None":
        # dummy / padding speaker slot
        return np.zeros((1, embedding_dim), dtype=np.float32)
    utts = spk2utts[spk]
    if average_emb_num == 0 or average_emb_num > len(utts):
        mats = [kaldiio.load_mat(utt2xvec[u]) for u in utts]
    else:
        mats = [kaldiio.load_mat(utt2xvec[u]) for u in random.sample(utts, average_emb_num)]
    stacked = np.concatenate(mats, axis=0)
    stacked = stacked / np.linalg.norm(stacked, axis=-1, keepdims=True)
    return np.mean(stacked, axis=0)
def simu_chunk(
    frame_label: np.ndarray,
    sample_label: np.ndarray,
    wav_scp: Dict[str, str],
    utt2xvec: Dict[str, str],
    spk2utts: Dict[str, List[str]],
    meeting2spks: Dict[str, List[str]],
    all_speaker_list: List[str],
    meeting_list: List[str],
    embedding_dim: int,
    average_emb_num: int,
):
    """Simulate one training chunk from a real frame-level label pattern.

    Keeps the speaker-activity pattern of ``frame_label`` but replaces its
    speakers with randomly drawn ones: active slots get speakers drawn
    preferably from a single random meeting, inactive slots get random
    negative speakers or the dummy ``"None"``.

    Returns:
        ``(mixed_wav, separated_wav, profile, frame_label)`` with the
        speaker slots shuffled consistently across all three outputs.
    """
    frame_length, max_spk_num = frame_label.shape
    sample_length = sample_label.shape[0]
    num_active = int(np.sum(frame_label.sum(axis=0) > 0))

    # Positive speakers: take them from one random meeting when it has
    # enough, otherwise top up with speakers from the global pool.
    positives = deepcopy(meeting2spks[random.choice(meeting_list)])
    if len(positives) >= num_active:
        positives = random.sample(positives, num_active)
    else:
        while len(positives) < num_active:
            cand = random.choice(all_speaker_list)
            if cand not in positives:
                positives.append(cand)

    # Negative speakers fill some of the inactive slots; "None" pads the rest.
    num_negative = random.randint(0, max_spk_num - num_active)
    negatives = []
    while len(negatives) < num_negative:
        cand = random.choice(all_speaker_list)
        if cand not in positives and cand not in negatives:
            negatives.append(cand)
    negatives.extend(["None"] * (max_spk_num - num_active - num_negative))

    random.shuffle(positives)
    random.shuffle(negatives)

    # Per-slot source wavs; only active slots get audio.
    separated_wav = np.zeros(sample_label.shape, dtype=np.float32)
    chunk_speakers = []
    for slot, active_frames in enumerate(frame_label.sum(axis=0)):
        if active_frames > 0:
            spk = positives.pop(0)
            separated_wav[:, slot] = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
        else:
            spk = negatives.pop(0)
        chunk_speakers.append(spk)

    # Mixture = sum of each slot's wav gated by its sample-level activity.
    mixed_wav = np.sum(separated_wav * sample_label, axis=1)

    # Shuffle slot order consistently across speakers, wavs and labels.
    order = list(range(max_spk_num))
    random.shuffle(order)
    chunk_speakers = [chunk_speakers[i] for i in order]
    separated_wav = separated_wav[:, order]
    frame_label = frame_label[:, order]

    profile = np.vstack([
        calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
        for spk in chunk_speakers
    ])
    return mixed_wav, separated_wav, profile, frame_label
def process(task_args):
    """Worker entry: simulate training chunks for one shard of meetings.

    Slides a window of ``args.chunk_size`` frames with stride
    ``args.chunk_shift`` (plus a per-meeting random offset) over each
    meeting's frame-label matrix; per window, synthesizes a mixed waveform,
    speaker-profile matrix and frame labels via ``simu_chunk`` and writes
    them as kaldi ark/scp pairs ``wav_mix.N``, ``profile.N`` and
    ``frame_label.N`` under ``args.out_dir``.

    Returns:
        Number of simulated chunks (aggregated by ``MyRunner.post``).
    """
    task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
    logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))
    # One ark/scp writer per output stream, suffixed with the 1-based worker index.
    out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
    wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
    # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
    profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1))
    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())
    # First pass: load labels and count chunks per meeting so progress can be
    # reported against a known total.
    labels_list = []
    total_chunks = 0
    for org_mid, label_path in task_list:
        whole_label = kaldiio.load_mat(label_path)
        # random offset to keep diversity (window starts differ between runs)
        rand_shift = random.randint(0, args.chunk_shift)
        num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
        labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
        total_chunks += num_chunk
    idx = 0
    simu_chunk_count = 0
    for org_mid, whole_label, rand_shift, num_chunk in labels_list:
        for i in range(num_chunk):
            idx = idx + 1
            # Window bounds in frames for this chunk.
            st = i * args.chunk_shift + rand_shift
            ed = i * args.chunk_shift + args.chunk_size + rand_shift
            utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
                args.subset + 1, task_idx + 1, org_mid, st, ed
            )
            frame_label = whole_label[st: ed, :]
            # Expand frame labels to sample resolution for waveform gating.
            # NOTE(review): args.sr is not declared in prepare() — presumably
            # supplied by the runner base class; confirm.
            sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
            mix_wav, seg_wav, profile, frame_label = simu_chunk(
                frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
                speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
            )
            wav_mix_writer(utt_id, mix_wav)
            # wav_sep_writer(utt_id, seg_wav)
            profile_writer(utt_id, profile)
            label_writer(utt_id, frame_label)
            simu_chunk_count += 1
            if simu_chunk_count % args.log_interval == 0:
                logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
                    task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
    wav_mix_writer.close()
    # wav_sep_writer.close()
    profile_writer.close()
    label_writer.close()
    logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
    return simu_chunk_count
if __name__ == '__main__':
    # Script entry: hand `process` to the multi-process runner.
    runner = MyRunner(process)
    runner.run()