diff --git a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml deleted file mode 100644 index 459a7414f..000000000 --- a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml +++ /dev/null @@ -1,121 +0,0 @@ -model: sond -model_conf: - lsm_weight: 0.0 - length_normalized_loss: true - max_spk_num: 16 - -# speech encoder -encoder: ecapa_tdnn -encoder_conf: - # pass by model, equal to feature dim - # input_size: 80 - pool_size: 20 - stride: 1 -speaker_encoder: conv -speaker_encoder_conf: - input_units: 256 - num_layers: 3 - num_units: 256 - kernel_size: 1 - dropout_rate: 0.0 - position_encoder: null - out_units: 256 - out_norm: false - auxiliary_states: false - tf2torch_tensor_name_prefix_torch: speaker_encoder - tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder -ci_scorer: dot -ci_scorer_conf: {} -cd_scorer: san -cd_scorer_conf: - input_size: 512 - output_size: 512 - out_units: 1 - attention_heads: 4 - linear_units: 1024 - num_blocks: 4 - dropout_rate: 0.0 - positional_dropout_rate: 0.0 - attention_dropout_rate: 0.0 - # use string "null" to remove input layer - input_layer: "null" - pos_enc_class: null - normalize_before: true - tf2torch_tensor_name_prefix_torch: cd_scorer - tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer -# post net -decoder: fsmn -decoder_conf: - in_units: 32 - out_units: 2517 - filter_size: 31 - fsmn_num_layers: 6 - dnn_num_layers: 1 - num_memory_units: 512 - ffn_inner_dim: 512 - dropout_rate: 0.0 - tf2torch_tensor_name_prefix_torch: decoder - tf2torch_tensor_name_prefix_tf: EAND/post_net -frontend: wav_frontend -frontend_conf: - fs: 16000 - window: povey - n_mels: 80 - frame_length: 25 - frame_shift: 10 - filter_length_min: -1 - filter_length_max: -1 - lfr_m: 1 - lfr_n: 1 - dither: 0.0 - snip_edges: false - -# minibatch related -batch_type: length -# 16s * 16k * 16 samples -batch_bins: 4096000 -num_workers: 8 - -# optimization related -accum_grad: 1 -grad_clip: 5 -max_epoch: 50 -val_scheduler_criterion: - - valid - - acc -best_model_criterion: -- - valid - - der - - min -- - valid - - forward_steps - - max -keep_nbest_models: 10 - -optim: adam -optim_conf: - lr: 0.001 -scheduler: warmuplr -scheduler_conf: - warmup_steps: 10000 - -# without spec aug -specaug: null -specaug_conf: - apply_time_warp: true - time_warp_window: 5 - time_warp_mode: bicubic - apply_freq_mask: true - freq_mask_width_range: - - 0 - - 30 - num_freq_mask: 2 - apply_time_mask: true - time_mask_width_range: - - 0 - - 40 - num_time_mask: 2 - -log_interval: 50 -# without normalize -normalize: None diff --git a/egs/mars/sd/local_run.sh b/egs/mars/sd/local_run.sh deleted file mode 100755 index 4516e9f96..000000000 --- a/egs/mars/sd/local_run.sh +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env bash - -. ./path.sh || exit 1; - -# machines configuration -CUDA_VISIBLE_DEVICES="6,7" -gpu_num=2 -count=1 -gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding -# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob -njob=5 -train_cmd=utils/run.pl -infer_cmd=utils/run.pl - -# general configuration -feats_dir="." #feature output dictionary -exp_dir="." -lang=zh -dumpdir=dump/raw -feats_type=raw -token_type=char -scp=wav.scp -type=kaldi_ark -stage=3 -stop_stage=4 - -# feature configuration -feats_dim= -sample_frequency=16000 -nj=32 -speed_perturb= - -# exp tag -tag="exp1" - -. utils/parse_options.sh || exit 1; - -# Set bash to 'debug' mode, it will exit on : -# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', -set -e -set -u -set -o pipefail - -train_set=train -valid_set=dev -test_sets="dev test" - -asr_config=conf/train_asr_conformer.yaml -model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}" - -inference_config=conf/decode_asr_transformer.yaml -inference_asr_model=valid.acc.ave_10best.pb - -# you can set gpu num for decoding here -gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default -ngpu=$(echo $gpuid_list | awk -F "," '{print NF}') - -if ${gpu_inference}; then - inference_nj=$[${ngpu}*${njob}] - _ngpu=1 -else - inference_nj=$njob - _ngpu=0 -fi - -feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir} -feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir} -feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir} - -# Training Stage -world_size=$gpu_num # run on one machine -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - echo "stage 3: Training" - mkdir -p ${exp_dir}/exp/${model_dir} - mkdir -p ${exp_dir}/exp/${model_dir}/log - INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init - if [ -f $INIT_FILE ];then - rm -f $INIT_FILE - fi - init_method=file://$(readlink -f $INIT_FILE) - echo "$0: init method is $init_method" - for ((i = 0; i < $gpu_num; ++i)); do - { - rank=$i - local_rank=$i - gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) - asr_train.py \ - --gpu_id $gpu_id \ - --use_preprocessor true \ - --token_type char \ - --token_list $token_list \ - --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \ - --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \ - --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \ - --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \ - --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \ - --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \ - --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \ - --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \ - --resume true \ - --output_dir ${exp_dir}/exp/${model_dir} \ - --config $asr_config \ - --input_size $feats_dim \ - --ngpu $gpu_num \ - --num_worker_count $count \ - --multiprocessing_distributed true \ - --dist_init_method $init_method \ - --dist_world_size $world_size \ - --dist_rank $rank \ - --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1 - } & - done - wait -fi - -# Testing Stage -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - echo "stage 4: Inference" - for dset in ${test_sets}; do - asr_exp=${exp_dir}/exp/${model_dir} - inference_tag="$(basename "${inference_config}" .yaml)" - _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}" - _logdir="${_dir}/logdir" - if [ -d ${_dir} ]; then - echo "${_dir} is already exists. if you want to decode again, please delete this dir first." - exit 0 - fi - mkdir -p "${_logdir}" - _data="${feats_dir}/${dumpdir}/${dset}" - key_file=${_data}/${scp} - num_scp_file="$(<${key_file} wc -l)" - _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file") - split_scps= - for n in $(seq "${_nj}"); do - split_scps+=" ${_logdir}/keys.${n}.scp" - done - # shellcheck disable=SC2086 - utils/split_scp.pl "${key_file}" ${split_scps} - _opts= - if [ -n "${inference_config}" ]; then - _opts+="--config ${inference_config} " - fi - ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1: "${_nj}" "${_logdir}"/asr_inference.JOB.log \ - python -m funasr.bin.asr_inference_launch \ - --batch_size 1 \ - --ngpu "${_ngpu}" \ - --njob ${njob} \ - --gpuid_list ${gpuid_list} \ - --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ - --key_file "${_logdir}"/keys.JOB.scp \ - --asr_train_config "${asr_exp}"/config.yaml \ - --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ - --output_dir "${_logdir}"/output.JOB \ - --mode asr \ - ${_opts} - - for f in token token_int score text; do - if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then - for i in $(seq "${_nj}"); do - cat "${_logdir}/output.${i}/1best_recog/${f}" - done | sort -k1 >"${_dir}/${f}" - fi - done - python utils/proce_text.py ${_dir}/text ${_dir}/text.proc - python utils/proce_text.py ${_data}/text ${_data}/text.proc - python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer - tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt - cat ${_dir}/text.cer.txt - done -fi - diff --git a/egs/mars/sd/path.sh b/egs/mars/sd/path.sh deleted file mode 100755 index 7972642d0..000000000 --- a/egs/mars/sd/path.sh +++ /dev/null @@ -1,5 +0,0 @@ -export FUNASR_DIR=$PWD/../../.. - -# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C -export PYTHONIOENCODING=UTF-8 -export PATH=$FUNASR_DIR/funasr/bin:$PATH diff --git a/egs/mars/sd/scripts/calculate_shapes.py b/egs/mars/sd/scripts/calculate_shapes.py deleted file mode 100644 index b207f2d3a..000000000 --- a/egs/mars/sd/scripts/calculate_shapes.py +++ /dev/null @@ -1,45 +0,0 @@ -import logging -import numpy as np -import soundfile -import kaldiio -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import os -import argparse -from collections import OrderedDict - - -class MyRunner(MultiProcessRunnerV3): - - def prepare(self, parser: argparse.ArgumentParser): - parser.add_argument("--input_scp", type=str, required=True) - parser.add_argument("--out_path") - args = parser.parse_args() - - if not os.path.exists(os.path.dirname(args.out_path)): - os.makedirs(os.path.dirname(args.out_path)) - - task_list = load_scp_as_list(args.input_scp) - return task_list, None, args - - def post(self, result_list, args): - fd = open(args.out_path, "wt", encoding="utf-8") - for results in result_list: - for uttid, shape in results: - fd.write("{} {}\n".format(uttid, ",".join(shape))) - fd.close() - - -def process(task_args): - task_idx, task_list, _, args = task_args - rst = [] - for uttid, file_path in task_list: - data = kaldiio.load_mat(file_path) - shape = [str(x) for x in data.shape] - rst.append((uttid, shape)) - return rst - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run() diff --git a/egs/mars/sd/scripts/dump_rttm_to_labels.py b/egs/mars/sd/scripts/dump_rttm_to_labels.py deleted file mode 100644 index ec1c76568..000000000 --- a/egs/mars/sd/scripts/dump_rttm_to_labels.py +++ /dev/null @@ -1,140 +0,0 @@ -import logging -import numpy as np -import soundfile -import kaldiio -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import os -import argparse -from collections import OrderedDict - - -class MyRunner(MultiProcessRunnerV3): - - def prepare(self, parser: argparse.ArgumentParser): - parser.add_argument("--rttm_list", type=str, required=True) - parser.add_argument("--wav_scp_list", type=str, required=True) - parser.add_argument("--out_dir", type=str, required=True) - parser.add_argument("--n_spk", type=int, default=8) - parser.add_argument("--remove_sil", default=False, action="store_true") - parser.add_argument("--max_overlap", default=0, type=int) - parser.add_argument("--frame_shift", type=float, default=0.01) - args = parser.parse_args() - - rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()] - meeting2rttm = OrderedDict() - for rttm_path in rttm_list: - meeting2rttm.update(self.load_rttm(rttm_path)) - - wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()] - meeting_scp = OrderedDict() - for scp_path in wav_scp_list: - meeting_scp.update(load_scp_as_dict(scp_path)) - - if len(meeting_scp) != len(meeting2rttm): - logging.warning("Number of wav and rttm mismatch {} != {}".format( - len(meeting_scp), len(meeting2rttm))) - common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys()) - logging.warning("Keep {} records.".format(len(common_keys))) - new_meeting_scp = OrderedDict() - rm_keys = [] - for key in meeting_scp: - if key not in common_keys: - rm_keys.append(key) - else: - new_meeting_scp[key] = meeting_scp[key] - logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys))) - - new_meeting2rttm = OrderedDict() - rm_keys = [] - for key in meeting2rttm: - if key not in common_keys: - rm_keys.append(key) - else: - new_meeting2rttm[key] = meeting2rttm[key] - logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys))) - meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm - if not os.path.exists(args.out_dir): - os.makedirs(args.out_dir) - - task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()] - return task_list, None, args - - @staticmethod - def load_rttm(rttm_path): - meeting2rttm = OrderedDict() - for one_line in open(rttm_path, "rt", encoding="utf-8"): - mid = one_line.strip().split(" ")[1] - if mid not in meeting2rttm: - meeting2rttm[mid] = [] - meeting2rttm[mid].append(one_line.strip()) - - return meeting2rttm - - def post(self, results_list, args): - pass - - -def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0, - sr=None, frame_shift=0.01): - frame_shift = int(frame_shift * sr) - num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift) - multi_label = np.zeros([n_spk, num_frame], dtype=np.float32) - for _, st, dur, spk in spk_turns: - idx = spk_list.index(spk) - - st, dur = int(st * sr), int(dur * sr) - frame_st = int((float(st) + (float(frame_shift) / 2)) / frame_shift) - frame_ed = int((float(st+dur) + (float(frame_shift) / 2)) / frame_shift) - multi_label[idx, frame_st:frame_ed] = 1 - - if remove_sil: - speech_count = np.sum(multi_label, axis=0) - idx = np.nonzero(speech_count)[0] - multi_label = multi_label[:, idx] - - if max_overlap > 0: - speech_count = np.sum(multi_label, axis=0) - idx = np.nonzero(speech_count <= max_overlap)[0] - multi_label = multi_label[:, idx] - - label = multi_label.T - return label # (T, N) - - -def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0, - sr=16000, frame_shift=0.01): - wav, sr = soundfile.read(wav_path) - wav_len = len(wav) - spk_turns = [] - spk_list = [] - for one_line in rttms: - parts = one_line.strip().split(" ") - mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7] - if spk not in spk_list: - spk_list.append(spk) - spk_turns.append((mid, st, dur, spk)) - labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift) - return labels, spk_list - - -def process(task_args): - task_idx, task_list, _, args = task_args - spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)), - "wt", encoding="utf-8") - out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1)) - label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) - for mid, wav_path, rttms in task_list: - meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap, - args.sr, args.frame_shift) - label_writer(mid, meeting_labels) - spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list))) - - spk_list_writer.close() - label_writer.close() - return None - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run() diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py deleted file mode 100644 index cd1ec7b34..000000000 --- a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py +++ /dev/null @@ -1,115 +0,0 @@ -import numpy as np -import os -import argparse -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import soundfile as sf -from tqdm import tqdm - - -class MyRunner(MultiProcessRunnerV3): - def prepare(self, parser): - assert isinstance(parser, argparse.ArgumentParser) - parser.add_argument("wav_scp", type=str) - parser.add_argument("rttm", type=str) - parser.add_argument("out_dir", type=str) - parser.add_argument("--min_dur", type=float, default=2.0) - parser.add_argument("--max_spk_num", type=int, default=4) - args = parser.parse_args() - - if not os.path.exists(args.out_dir): - os.makedirs(args.out_dir) - - wav_scp = load_scp_as_list(args.wav_scp) - meeting2rttms = {} - for one_line in open(args.rttm, "rt"): - parts = [x for x in one_line.strip().split(" ") if x != ""] - mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7] - if mid not in meeting2rttms: - meeting2rttms[mid] = [] - meeting2rttms[mid].append(one_line) - - task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp] - return task_list, None, args - - def post(self, result_list, args): - count = [0, 0] - for result in result_list: - count[0] += result[0] - count[1] += result[1] - print("Found {} speakers, extracted {}.".format(count[1], count[0])) - - -# SPEAKER R8001_M8004_MS801 1 6.90 11.39 1 -def calc_multi_label(rttms, length, sr=8000, max_spk_num=4): - labels = np.zeros([max_spk_num, length], int) - spk_list = [] - for one_line in rttms: - parts = [x for x in one_line.strip().split(" ") if x != ""] - mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7] - spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "") - if spk_name.isdigit(): - spk_name = "{}_S{:03d}".format(mid, int(spk_name)) - else: - spk_name = "{}_{}".format(mid, spk_name) - if spk_name not in spk_list: - spk_list.append(spk_name) - st, dur = int(st*sr), int(dur*sr) - idx = spk_list.index(spk_name) - labels[idx, st:st+dur] = 1 - return labels, spk_list - - -def get_nonoverlap_turns(multi_label, spk_list): - turns = [] - label = np.sum(multi_label, axis=0) == 1 - spk, in_turn, st = None, False, 0 - for i in range(len(label)): - if not in_turn and label[i]: - st, in_turn = i, True - spk = spk_list[np.argmax(multi_label[:, i], axis=0)] - if in_turn: - if not label[i]: - in_turn = False - turns.append([st, i, spk]) - elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]: - turns.append([st, i, spk]) - st, in_turn = i, True - spk = spk_list[np.argmax(multi_label[:, i], axis=0)] - if in_turn: - turns.append([st, len(label), spk]) - return turns - - -def process(task_args): - task_id, task_list, _, args = task_args - spk_count = [0, 0] - for mid, wav_path, rttms in task_list: - wav, sr = sf.read(wav_path, dtype="int16") - assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr) - multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num) - turns = get_nonoverlap_turns(multi_label, spk_list) - extracted_spk = [] - count = 1 - for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar): - if (ed - st) >= args.min_dur * args.sr: - seg = wav[st: ed] - save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count)) - if not os.path.exists(os.path.dirname(save_path)): - os.makedirs(os.path.dirname(save_path)) - sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True) - count += 1 - if spk not in extracted_spk: - extracted_spk.append(spk) - if len(extracted_spk) != len(spk_list): - print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format( - mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk]) - )) - spk_count[0] += len(extracted_spk) - spk_count[1] += len(spk_list) - return spk_count - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run() diff --git a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py b/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py deleted file mode 100644 index e579f51b9..000000000 --- a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import os -import librosa -import argparse - - -class MyRunner(MultiProcessRunnerV3): - - def prepare(self, parser): - parser.add_argument("dir", type=str) - parser.add_argument("out_dir", type=str) - parser.add_argument("--n_spk", type=int, default=4) - parser.add_argument("--remove_sil", default=False, action="store_true") - args = parser.parse_args() - - meeting_scp = load_scp_as_dict(os.path.join(args.dir, "meeting.scp")) - rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp")) - - if not os.path.exists(args.out_dir): - os.makedirs(args.out_dir) - - task_list = [(mid, meeting_scp[mid], rttm_path) for mid, rttm_path in rttm_scp] - return task_list, None, args - - def post(self, results_list, args): - pass - - -def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000): - multi_label = np.zeros([n_spk, length], dtype=int) - for _, st, dur, spk in spk_turns: - st, dur = int(st * sr), int(dur * sr) - idx = spk_list.index(spk) - multi_label[idx, st:st+dur] = 1 - if not remove_sil: - return multi_label.T - - speech_count = np.sum(multi_label, axis=0) - idx = np.nonzero(speech_count)[0] - label = multi_label[:, idx].T - return label # (T, N) - - -def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000): - wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr) - spk_turns = [] - spk_list = [] - for one_line in open(rttm_path, "rt"): - parts = one_line.strip().split(" ") - mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7]) - spk = "{}_S{:03d}".format(mid, spk) - if spk not in spk_list: - spk_list.append(spk) - spk_turns.append((mid, st, dur, spk)) - labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil) - return labels - - -def process(task_args): - _, task_list, _, args = task_args - for mid, wav_path, rttm_path in task_list: - meeting_labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil) - save_path = os.path.join(args.out_dir, "{}.lbl".format(mid)) - np.save(save_path, meeting_labels.astype(bool)) - print(mid) - return None - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run() diff --git a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py b/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py deleted file mode 100644 index 11bc39503..000000000 --- a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import os -import librosa -import soundfile as sf -from tqdm import tqdm -import argparse - - -class MyRunner(MultiProcessRunnerV3): - - def prepare(self, parser): - parser.add_argument("wav_scp", type=str) - parser.add_argument("out_dir", type=str) - parser.add_argument("--chunk_dur", type=float, default=16) - parser.add_argument("--shift_dur", type=float, default=4) - args = parser.parse_args() - - if not os.path.exists(args.out_dir): - os.makedirs(args.out_dir) - - wav_scp = load_scp_as_list(args.wav_scp) - return wav_scp, None, args - - def post(self, results_list, args): - pass - - -def process(task_args): - _, task_list, _, args = task_args - chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr) - for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar): - if not os.path.exists(os.path.join(args.out_dir, mid)): - os.makedirs(os.path.join(args.out_dir, mid)) - - wav = librosa.load(wav_path, args.sr, True)[0] * 32767 - n_chunk = (len(wav) - chunk_len) // shift_len + 1 - if (len(wav) - chunk_len) % shift_len > 0: - n_chunk += 1 - for i in range(n_chunk): - seg = wav[i*shift_len: i*shift_len + chunk_len] - st = int(float(i*shift_len)/args.sr * 100) - dur = int(float(len(seg))/args.sr * 100) - file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur) - save_path = os.path.join(args.out_dir, mid, file_name) - sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True) - return None - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run() diff --git a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py b/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py deleted file mode 100644 index 011bd7c6a..000000000 --- a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py +++ /dev/null @@ -1,57 +0,0 @@ -import numpy as np -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import os -import argparse - - -class MyRunner(MultiProcessRunnerV3): - - def prepare(self, parser): - parser.add_argument("--rttm_scp", type=str) - parser.add_argument("--seg_file", type=str) - args = parser.parse_args() - - if not os.path.exists(os.path.dirname(args.seg_file)): - os.makedirs(os.path.dirname(args.seg_file)) - - task_list = load_scp_as_list(args.rttm_scp) - return task_list, None, args - - def post(self, results_list, args): - with open(args.seg_file, "wt", encoding="utf-8") as fd: - for results in results_list: - fd.writelines(results) - - -def process(task_args): - _, task_list, _, args = task_args - outputs = [] - for mid, rttm_path in task_list: - spk_turns = [] - length = 0 - for one_line in open(rttm_path, 'rt', encoding="utf-8"): - parts = one_line.strip().split(" ") - _, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7] - st, ed = int(st*100), int((st + dur)*100) - length = ed if ed > length else length - spk_turns.append([mid, st, ed, spk_name]) - is_sph = np.zeros((length+1, ), dtype=bool) - for _, st, ed, _ in spk_turns: - is_sph[st:ed] = True - - st, in_speech = 0, False - for i in range(length+1): - if not in_speech and is_sph[i]: - st, in_speech = i, True - if in_speech and not is_sph[i]: - in_speech = False - outputs.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format( - mid, st, i, mid, float(st)/100, float(i)/100 - )) - return outputs - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run() diff --git a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py b/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py deleted file mode 100644 index a2bcd390a..000000000 --- a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py +++ /dev/null @@ -1,138 +0,0 @@ -import soundfile -import kaldiio -from tqdm import tqdm -import json -import os -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import numpy as np -import argparse -import random - -short_spk_list = [] -def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000): - all_utts = spk2utt[spk] - idx_list = list(range(len(all_utts))) - random.shuffle(idx_list) - count = 0 - utt_list = [] - for i in idx_list: - utt_id = all_utts[i] - utt_list.append(utt_id) - count += int(utt2frames[utt_id]) - if count >= total_len: - break - if count < 300 and spk not in short_spk_list: - print("Speaker {} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300)) - short_spk_list.append(spk) - - ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list] - ivc_list = [x/np.linalg.norm(x, axis=-1) for x in ivc_list] - ivc = np.concatenate(ivc_list, axis=0) - ivc = np.mean(ivc, axis=0, keepdims=False) - return ivc - - -def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args): - out_prefix = args.out - - ivc_dim = 192 - win_len, win_shift = 400, 160 - label_weights = 2 ** np.array(list(range(args.n_spk))) - wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix)) - ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix)) - label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix)) - - - frames_list = [] - chunk_size = int(args.chunk_size * args.sr) - chunk_shift = int(args.chunk_shift * args.sr) - for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar): - meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32') - num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1 - meeting_labels = np.load(labels_scp[mid]) - for i in range(num_chunk): - st, ed = i*chunk_shift, i*chunk_shift+chunk_size - seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100)) - wav_writer(seg_id, meeting_wav[st: ed]) - - xvec_list = [] - for spk in meeting2spk_list[mid]: - spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000) - xvec_list.append(spk_xvec) - for _ in range(args.n_spk - len(xvec_list)): - xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32)) - xvec = np.row_stack(xvec_list) - ivc_writer(seg_id, xvec) - - wav_label = meeting_labels[st:ed, :] - frame_num = (ed-st) // win_shift - # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant") - feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32) - for i in range(frame_num): - frame_label = wav_label[i*win_shift: (i+1)*win_shift, :] - feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32) - label_writer(seg_id, feat_label) - - frames_list.append((mid, feat_label.shape[0])) - return frames_list - - -def calc_spk_list(rttm_path): - spk_list = [] - for one_line in open(rttm_path, "rt"): - parts = one_line.strip().split(" ") - mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7]) - spk = "{}_S{:03d}".format(mid, spk) - if spk not in spk_list: - spk_list.append(spk) - - return spk_list - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--dir", required=True, type=str, default=None, - help="feats.scp") - parser.add_argument("--out", required=True, type=str, default=None, - help="The prefix of dumpped files.") - parser.add_argument("--n_spk", type=int, default=4) - parser.add_argument("--use_lfr", default=False, action="store_true") - parser.add_argument("--no_pbar", default=False, action="store_true") - parser.add_argument("--sr", type=int, default=16000) - parser.add_argument("--chunk_size", type=int, default=16) - parser.add_argument("--chunk_shift", type=int, default=4) - args = parser.parse_args() - - if not os.path.exists(os.path.dirname(args.out)): - os.makedirs(os.path.dirname(args.out)) - - meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp")) - labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp")) - rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp")) - utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk")) - utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec")) - utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp")) - utt2frames = {} - for uttid, wav_path in utt2wav.items(): - wav, sr = soundfile.read(wav_path, dtype="int16") - utt2frames[uttid] = int(len(wav) / sr * 100) - - meeting2spk_list = {} - for mid, rttm_path in rttm_scp: - meeting2spk_list[mid] = calc_spk_list(rttm_path) - - spk2utt = {} - for utt, spk in utt2spk.items(): - if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25: - if spk not in spk2utt: - spk2utt[spk] = [] - spk2utt[spk].append(utt) - - # random.shuffle(feat_scp) - meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args) - total_frames = sum([x[1] for x in meeting_lens]) - print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames)) - - -if __name__ == '__main__': - main() diff --git a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py deleted file mode 100644 index 1d6f53e92..000000000 --- a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py +++ /dev/null @@ -1,110 +0,0 @@ -from __future__ import print_function -import numpy as np -import os -import sys -import argparse -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import librosa -import soundfile as sf -from copy import deepcopy -import json -from tqdm import tqdm - - -class MyRunner(MultiProcessRunnerV3): - def prepare(self, parser): - assert isinstance(parser, argparse.ArgumentParser) - parser.add_argument("wav_scp", type=str) - parser.add_argument("rttm_scp", type=str) - parser.add_argument("out_dir", type=str) - parser.add_argument("--min_dur", type=float, default=2.0) - parser.add_argument("--max_spk_num", type=int, default=4) - args = parser.parse_args() - - if not os.path.exists(args.out_dir): - os.makedirs(args.out_dir) - - wav_scp = load_scp_as_list(args.wav_scp) - rttm_scp = load_scp_as_dict(args.rttm_scp) - task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp] - return task_list, None, args - - def post(self, result_list, args): - count = [0, 0] - for result in result_list: - count[0] += result[0] - count[1] += result[1] - print("Found {} speakers, extracted {}.".format(count[1], count[0])) - - -# SPEAKER R8001_M8004_MS801 1 6.90 11.39 1 -def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4): - labels = np.zeros([max_spk_num, length], int) - spk_list = [] - for one_line in open(rttm_path, 'rt'): - parts = one_line.strip().split(" ") - mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7] - if spk_name.isdigit(): - spk_name = "{}_S{:03d}".format(mid, int(spk_name)) - if spk_name not in spk_list: - spk_list.append(spk_name) - st, dur = int(st*sr), int(dur*sr) - idx = spk_list.index(spk_name) - labels[idx, st:st+dur] = 1 - return labels, spk_list - - -def get_nonoverlap_turns(multi_label, spk_list): - turns = [] - label = np.sum(multi_label, axis=0) == 1 - spk, in_turn, st = None, False, 0 - for i in range(len(label)): - if not in_turn and label[i]: - st, in_turn = i, True - spk = spk_list[np.argmax(multi_label[:, i], axis=0)] - if in_turn: - if not label[i]: - in_turn = False - turns.append([st, i, spk]) - elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]: - turns.append([st, i, spk]) - st, in_turn = i, True - spk = spk_list[np.argmax(multi_label[:, i], axis=0)] - if in_turn: - turns.append([st, len(label), spk]) - return turns - - -def process(task_args): - task_id, task_list, _, args = task_args - spk_count = [0, 0] - for mid, wav_path, rttm_path in task_list: - wav, sr = sf.read(wav_path, dtype="int16") - assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr) - multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num) - turns = get_nonoverlap_turns(multi_label, spk_list) - extracted_spk = [] - count = 1 - for st, ed, spk in tqdm(turns, total=len(turns), ascii=True): - if (ed - st) >= args.min_dur * args.sr: - seg = wav[st: ed] - save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count)) - if not os.path.exists(os.path.dirname(save_path)): - os.makedirs(os.path.dirname(save_path)) - sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True) - count += 1 - if spk not in extracted_spk: - extracted_spk.append(spk) - if len(extracted_spk) != len(spk_list): - print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format( - mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk]) - )) - spk_count[0] += len(extracted_spk) - spk_count[1] += len(spk_list) - return spk_count - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run() diff --git a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py b/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py deleted file mode 100644 index 8b3195f7c..000000000 --- a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import os -import librosa -import soundfile as sf -import argparse - - -class MyRunner(MultiProcessRunnerV3): - - def prepare(self, parser): - parser.add_argument("dir", type=str) - parser.add_argument("out_dir", type=str) - args = parser.parse_args() - - meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp")) - vad_file = open(os.path.join(args.dir, "segments"), encoding="utf-8") - meeting2vad = {} - for one_line in vad_file: - uid, mid, st, ed = one_line.strip().split(" ") - st, ed = int(float(st) * args.sr), int(float(ed) * args.sr) - if mid not in meeting2vad: - meeting2vad[mid] = [] - meeting2vad[mid].append((uid, st, ed)) - - if not os.path.exists(args.out_dir): - os.makedirs(args.out_dir) - - task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp] - return task_list, None, args - - def post(self, results_list, args): - pass - - -def process(task_args): - _, task_list, _, args = task_args - for mid, wav_path, vad_list in task_list: - wav = librosa.load(wav_path, args.sr, True)[0] * 32767 - seg_list = [] - pos_map = [] - offset = 0 - for uid, st, ed in vad_list: - seg_list.append(wav[st: ed]) - pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset+ed-st)) - offset = offset + ed - st - out = np.concatenate(seg_list, axis=0) - save_path = os.path.join(args.out_dir, "{}.wav".format(mid)) - sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True) - map_path = os.path.join(args.out_dir, "{}.pos".format(mid)) - with open(map_path, "wt", encoding="utf-8") as fd: - fd.writelines(pos_map) - print(mid) - return None - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run() diff --git a/egs/mars/sd/scripts/simu_chunk_with_labels.py b/egs/mars/sd/scripts/simu_chunk_with_labels.py deleted file mode 100644 index f61b8083e..000000000 --- a/egs/mars/sd/scripts/simu_chunk_with_labels.py +++ /dev/null @@ -1,261 +0,0 @@ -import logging -import numpy as np -import soundfile -import kaldiio -from funasr.utils.job_runner import MultiProcessRunnerV3 -from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import os -import argparse -from collections import OrderedDict -import random -from typing import List, Dict -from copy import deepcopy -import json -logging.basicConfig( - level="INFO", - format=f"[{os.uname()[1].split('.')[0]}]" - f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", -) - - -class MyRunner(MultiProcessRunnerV3): - - def prepare(self, parser: argparse.ArgumentParser): - parser.add_argument("--label_scp", type=str, required=True) - parser.add_argument("--wav_scp", type=str, required=True) - parser.add_argument("--utt2spk", type=str, required=True) - parser.add_argument("--spk2meeting", type=str, required=True) - parser.add_argument("--utt2xvec", type=str, required=True) - parser.add_argument("--out_dir", type=str, required=True) - parser.add_argument("--chunk_size", type=float, default=16) - parser.add_argument("--chunk_shift", type=float, default=4) - parser.add_argument("--frame_shift", type=float, default=0.01) - parser.add_argument("--embedding_dim", type=int, default=None) - parser.add_argument("--average_emb_num", type=int, default=0) - parser.add_argument("--subset", type=int, default=0) - parser.add_argument("--data_json", type=str, default=None) - parser.add_argument("--seed", type=int, default=1234) - parser.add_argument("--log_interval", type=int, default=100) - args = parser.parse_args() - random.seed(args.seed) - np.random.seed(args.seed) - - logging.info("Loading data...") - if not os.path.exists(args.data_json): - label_list = load_scp_as_list(args.label_scp) - wav_scp = load_scp_as_dict(args.wav_scp) - utt2spk = load_scp_as_dict(args.utt2spk) - utt2xvec = load_scp_as_dict(args.utt2xvec) - spk2meeting = load_scp_as_dict(args.spk2meeting) - - meeting2spks = OrderedDict() - for spk, meeting in spk2meeting.items(): - if meeting not in meeting2spks: - meeting2spks[meeting] = [] - meeting2spks[meeting].append(spk) - - spk2utts = OrderedDict() - for utt, spk in utt2spk.items(): - if spk not in spk2utts: - spk2utts[spk] = [] - spk2utts[spk].append(utt) - - os.makedirs(os.path.dirname(args.data_json), exist_ok=True) - logging.info("Dump data...") - json.dump({ - "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec, - "spk2utts": spk2utts, "meeting2spks": meeting2spks - }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4) - else: - data_dict = json.load(open(args.data_json, "rt", encoding="utf-8")) - label_list = data_dict["label_list"] - wav_scp = data_dict["wav_scp"] - utt2xvec = data_dict["utt2xvec"] - spk2utts = data_dict["spk2utts"] - meeting2spks = data_dict["meeting2spks"] - - if not os.path.exists(args.out_dir): - os.makedirs(args.out_dir) - - args.chunk_size = int(args.chunk_size / args.frame_shift) - args.chunk_shift = int(args.chunk_shift / args.frame_shift) - - if args.embedding_dim is None: - args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1] - logging.info("Embedding dim is detected as {}.".format(args.embedding_dim)) - - logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format( - len(wav_scp), len(spk2utts), len(meeting2spks) - )) - return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args - - def post(self, results_list, args): - logging.info("[main]: Got {} chunks.".format(sum(results_list))) - - -def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length): - utt_list = spk2utts[spk] - wav_list = [] - cur_length = 0 - while cur_length < sample_length: - uttid = random.choice(utt_list) - wav, fs = soundfile.read(wav_scp[uttid], dtype='float32') - wav_list.append(wav) - cur_length += len(wav) - concat_wav = np.concatenate(wav_list, axis=0) - start = random.randint(0, len(concat_wav) - sample_length) - return concat_wav[start: start+sample_length] - - -def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num): - # process for dummy speaker - if spk == "None": - return np.zeros((1, embedding_dim), dtype=np.float32) - - # calculate averaged speaker embeddings - utt_list = spk2utts[spk] - if average_emb_num == 0 or average_emb_num > len(utt_list): - xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list] - else: - xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)] - xvec = np.concatenate(xvec_list, axis=0) - xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True) - xvec = np.mean(xvec, axis=0) - - return xvec - - -def simu_chunk( - frame_label: np.ndarray, - sample_label: np.ndarray, - wav_scp: Dict[str, str], - utt2xvec: Dict[str, str], - spk2utts: Dict[str, List[str]], - meeting2spks: Dict[str, List[str]], - all_speaker_list: List[str], - meeting_list: List[str], - embedding_dim: int, - average_emb_num: int, -): - frame_length, max_spk_num = frame_label.shape - sample_length = sample_label.shape[0] - positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0)) - pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)]) - - # get positive speakers - if len(pos_speaker_list) >= positive_speaker_num: - pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num) - else: - while len(pos_speaker_list) < positive_speaker_num: - _spk = random.choice(all_speaker_list) - if _spk not in pos_speaker_list: - pos_speaker_list.append(_spk) - - # get negative speakers - negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num) - neg_speaker_list = [] - while len(neg_speaker_list) < negative_speaker_num: - _spk = random.choice(all_speaker_list) - if _spk not in pos_speaker_list and _spk not in neg_speaker_list: - neg_speaker_list.append(_spk) - neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num)) - - random.shuffle(pos_speaker_list) - random.shuffle(neg_speaker_list) - seperated_wav = np.zeros(sample_label.shape, dtype=np.float32) - this_spk_list = [] - for idx, frame_num in enumerate(frame_label.sum(axis=0)): - if frame_num > 0: - spk = pos_speaker_list.pop(0) - this_spk_list.append(spk) - simu_spk_wav = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length) - seperated_wav[:, idx] = simu_spk_wav - else: - spk = neg_speaker_list.pop(0) - this_spk_list.append(spk) - - # calculate mixed wav - mixed_wav = np.sum(seperated_wav * sample_label, axis=1) - - # shuffle the order of speakers - shuffle_idx = list(range(max_spk_num)) - random.shuffle(shuffle_idx) - this_spk_list = [this_spk_list[x] for x in shuffle_idx] - seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose() - frame_label = frame_label.transpose()[shuffle_idx].transpose() - - # calculate profile - profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num) - for spk in this_spk_list] - profile = np.vstack(profile) - # pse_weights = 2 ** np.arange(max_spk_num) - # pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1) - # pse_label = pse_label.astype(str).tolist() - - return mixed_wav, seperated_wav, profile, frame_label - - -def process(task_args): - task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args - logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj)) - - out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1)) - wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) - - # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1)) - # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) - - out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1)) - profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) - - out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1)) - label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) - - speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys()) - - labels_list = [] - total_chunks = 0 - for org_mid, label_path in task_list: - whole_label = kaldiio.load_mat(label_path) - # random offset to keep diversity - rand_shift = random.randint(0, args.chunk_shift) - num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1 - labels_list.append((org_mid, whole_label, rand_shift, num_chunk)) - total_chunks += num_chunk - - idx = 0 - simu_chunk_count = 0 - for org_mid, whole_label, rand_shift, num_chunk in labels_list: - for i in range(num_chunk): - idx = idx + 1 - st = i * args.chunk_shift + rand_shift - ed = i * args.chunk_shift + args.chunk_size + rand_shift - utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format( - args.subset + 1, task_idx + 1, org_mid, st, ed - ) - frame_label = whole_label[st: ed, :] - sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0) - mix_wav, seg_wav, profile, frame_label = simu_chunk( - frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks, - speaker_list, meeting_list, args.embedding_dim, args.average_emb_num - ) - wav_mix_writer(utt_id, mix_wav) - # wav_sep_writer(utt_id, seg_wav) - profile_writer(utt_id, profile) - label_writer(utt_id, frame_label) - - simu_chunk_count += 1 - if simu_chunk_count % args.log_interval == 0: - logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format( - task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id)) - wav_mix_writer.close() - # wav_sep_writer.close() - profile_writer.close() - label_writer.close() - logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count)) - return simu_chunk_count - - -if __name__ == '__main__': - my_runner = MyRunner(process) - my_runner.run()