mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
commit 81123acf88 ("update repo"), parent 624cad13ba
@@ -1,121 +0,0 @@
model: sond
model_conf:
    lsm_weight: 0.0
    length_normalized_loss: true
    max_spk_num: 16

# speech encoder
encoder: ecapa_tdnn
encoder_conf:
    # pass by model, equal to feature dim
    # input_size: 80
    pool_size: 20
    stride: 1
speaker_encoder: conv
speaker_encoder_conf:
    input_units: 256
    num_layers: 3
    num_units: 256
    kernel_size: 1
    dropout_rate: 0.0
    position_encoder: null
    out_units: 256
    out_norm: false
    auxiliary_states: false
    tf2torch_tensor_name_prefix_torch: speaker_encoder
    tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
ci_scorer: dot
ci_scorer_conf: {}
cd_scorer: san
cd_scorer_conf:
    input_size: 512
    output_size: 512
    out_units: 1
    attention_heads: 4
    linear_units: 1024
    num_blocks: 4
    dropout_rate: 0.0
    positional_dropout_rate: 0.0
    attention_dropout_rate: 0.0
    # use string "null" to remove input layer
    input_layer: "null"
    pos_enc_class: null
    normalize_before: true
    tf2torch_tensor_name_prefix_torch: cd_scorer
    tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
# post net
decoder: fsmn
decoder_conf:
    in_units: 32
    out_units: 2517
    filter_size: 31
    fsmn_num_layers: 6
    dnn_num_layers: 1
    num_memory_units: 512
    ffn_inner_dim: 512
    dropout_rate: 0.0
    tf2torch_tensor_name_prefix_torch: decoder
    tf2torch_tensor_name_prefix_tf: EAND/post_net
frontend: wav_frontend
frontend_conf:
    fs: 16000
    window: povey
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    filter_length_min: -1
    filter_length_max: -1
    lfr_m: 1
    lfr_n: 1
    dither: 0.0
    snip_edges: false

# minibatch related
batch_type: length
# 16s * 16k * 16 samples
batch_bins: 4096000
num_workers: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 50
val_scheduler_criterion:
    - valid
    - acc
best_model_criterion:
    - - valid
      - der
      - min
    - - valid
      - forward_steps
      - max
keep_nbest_models: 10

optim: adam
optim_conf:
    lr: 0.001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 10000

# without spec aug
specaug: null
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
        - 0
        - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
        - 0
        - 40
    num_time_mask: 2

log_interval: 50
# without normalize
normalize: None
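The file above is plain YAML, so it can be sanity-checked outside of training. A minimal sketch, assuming PyYAML is installed and the config has been saved locally as conf/sond.yaml (a hypothetical path; the real file name is not shown in this diff):

# Minimal sanity check of the SOND diarization config (illustrative sketch, not repo code).
# "conf/sond.yaml" is a hypothetical local path; point it at wherever the file actually lives.
import yaml

with open("conf/sond.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# a few fields the recipe relies on
print(cfg["model"], cfg["encoder"], cfg["decoder"])                 # sond ecapa_tdnn fsmn
print(cfg["model_conf"]["max_spk_num"])                             # 16
print(cfg["frontend_conf"]["n_mels"], cfg["frontend_conf"]["fs"])   # 80 16000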
@@ -1,171 +0,0 @@
#!/usr/bin/env bash

. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="6,7"
gpu_num=2
count=1
gpu_inference=true    # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
feats_dir="."   # feature output directory
exp_dir="."
lang=zh
dumpdir=dump/raw
feats_type=raw
token_type=char
scp=wav.scp
type=kaldi_ark
stage=3
stop_stage=4

# feature configuration
feats_dim=
sample_frequency=16000
nj=32
speed_perturb=

# exp tag
tag="exp1"

. utils/parse_options.sh || exit 1;

# Set bash to 'debug' mode; it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

asr_config=conf/train_asr_conformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"

inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

if ${gpu_inference}; then
    inference_nj=$[${ngpu}*${njob}]
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi

feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}

# Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Training"
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ]; then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
            asr_train.py \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --input_size $feats_dim \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi

# Testing Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} already exists. If you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/${dumpdir}/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}

        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
@@ -1,5 +0,0 @@
export FUNASR_DIR=$PWD/../../..

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
@@ -1,45 +0,0 @@
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict


class MyRunner(MultiProcessRunnerV3):

    def prepare(self, parser: argparse.ArgumentParser):
        parser.add_argument("--input_scp", type=str, required=True)
        parser.add_argument("--out_path")
        args = parser.parse_args()

        if not os.path.exists(os.path.dirname(args.out_path)):
            os.makedirs(os.path.dirname(args.out_path))

        task_list = load_scp_as_list(args.input_scp)
        return task_list, None, args

    def post(self, result_list, args):
        fd = open(args.out_path, "wt", encoding="utf-8")
        for results in result_list:
            for uttid, shape in results:
                fd.write("{} {}\n".format(uttid, ",".join(shape)))
        fd.close()


def process(task_args):
    task_idx, task_list, _, args = task_args
    rst = []
    for uttid, file_path in task_list:
        data = kaldiio.load_mat(file_path)
        shape = [str(x) for x in data.shape]
        rst.append((uttid, shape))
    return rst


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
@@ -1,140 +0,0 @@
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict


class MyRunner(MultiProcessRunnerV3):

    def prepare(self, parser: argparse.ArgumentParser):
        parser.add_argument("--rttm_list", type=str, required=True)
        parser.add_argument("--wav_scp_list", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--n_spk", type=int, default=8)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        parser.add_argument("--max_overlap", default=0, type=int)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        args = parser.parse_args()

        rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()]
        meeting2rttm = OrderedDict()
        for rttm_path in rttm_list:
            meeting2rttm.update(self.load_rttm(rttm_path))

        wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()]
        meeting_scp = OrderedDict()
        for scp_path in wav_scp_list:
            meeting_scp.update(load_scp_as_dict(scp_path))

        if len(meeting_scp) != len(meeting2rttm):
            logging.warning("Number of wav and rttm mismatch {} != {}".format(
                len(meeting_scp), len(meeting2rttm)))
            common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
            logging.warning("Keep {} records.".format(len(common_keys)))
            new_meeting_scp = OrderedDict()
            rm_keys = []
            for key in meeting_scp:
                if key not in common_keys:
                    rm_keys.append(key)
                else:
                    new_meeting_scp[key] = meeting_scp[key]
            logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))

            new_meeting2rttm = OrderedDict()
            rm_keys = []
            for key in meeting2rttm:
                if key not in common_keys:
                    rm_keys.append(key)
                else:
                    new_meeting2rttm[key] = meeting2rttm[key]
            logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
            meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
        return task_list, None, args

    @staticmethod
    def load_rttm(rttm_path):
        meeting2rttm = OrderedDict()
        for one_line in open(rttm_path, "rt", encoding="utf-8"):
            mid = one_line.strip().split(" ")[1]
            if mid not in meeting2rttm:
                meeting2rttm[mid] = []
            meeting2rttm[mid].append(one_line.strip())

        return meeting2rttm

    def post(self, results_list, args):
        pass


def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
                sr=None, frame_shift=0.01):
    frame_shift = int(frame_shift * sr)
    num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
    multi_label = np.zeros([n_spk, num_frame], dtype=np.float32)
    for _, st, dur, spk in spk_turns:
        idx = spk_list.index(spk)

        st, dur = int(st * sr), int(dur * sr)
        frame_st = int((float(st) + (float(frame_shift) / 2)) / frame_shift)
        frame_ed = int((float(st+dur) + (float(frame_shift) / 2)) / frame_shift)
        multi_label[idx, frame_st:frame_ed] = 1

    if remove_sil:
        speech_count = np.sum(multi_label, axis=0)
        idx = np.nonzero(speech_count)[0]
        multi_label = multi_label[:, idx]

    if max_overlap > 0:
        speech_count = np.sum(multi_label, axis=0)
        idx = np.nonzero(speech_count <= max_overlap)[0]
        multi_label = multi_label[:, idx]

    label = multi_label.T
    return label  # (T, N)


def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
                 sr=16000, frame_shift=0.01):
    wav, sr = soundfile.read(wav_path)
    wav_len = len(wav)
    spk_turns = []
    spk_list = []
    for one_line in rttms:
        parts = one_line.strip().split(" ")
        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
        if spk not in spk_list:
            spk_list.append(spk)
        spk_turns.append((mid, st, dur, spk))
    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
    return labels, spk_list


def process(task_args):
    task_idx, task_list, _, args = task_args
    spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)),
                           "wt", encoding="utf-8")
    out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    for mid, wav_path, rttms in task_list:
        meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
                                                args.sr, args.frame_shift)
        label_writer(mid, meeting_labels)
        spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))

    spk_list_writer.close()
    label_writer.close()
    return None


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
@@ -1,115 +0,0 @@
import numpy as np
import os
import argparse
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import soundfile as sf
from tqdm import tqdm


class MyRunner(MultiProcessRunnerV3):
    def prepare(self, parser):
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("rttm", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        wav_scp = load_scp_as_list(args.wav_scp)
        meeting2rttms = {}
        for one_line in open(args.rttm, "rt"):
            parts = [x for x in one_line.strip().split(" ") if x != ""]
            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            if mid not in meeting2rttms:
                meeting2rttms[mid] = []
            meeting2rttms[mid].append(one_line)

        task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
        return task_list, None, args

    def post(self, result_list, args):
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("Found {} speakers, extracted {}.".format(count[1], count[0]))


# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
    labels = np.zeros([max_spk_num, length], int)
    spk_list = []
    for one_line in rttms:
        parts = [x for x in one_line.strip().split(" ") if x != ""]
        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
        spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
        if spk_name.isdigit():
            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
        else:
            spk_name = "{}_{}".format(mid, spk_name)
        if spk_name not in spk_list:
            spk_list.append(spk_name)
        st, dur = int(st*sr), int(dur*sr)
        idx = spk_list.index(spk_name)
        labels[idx, st:st+dur] = 1
    return labels, spk_list


def get_nonoverlap_turns(multi_label, spk_list):
    turns = []
    label = np.sum(multi_label, axis=0) == 1
    spk, in_turn, st = None, False, 0
    for i in range(len(label)):
        if not in_turn and label[i]:
            st, in_turn = i, True
            spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
        if in_turn:
            if not label[i]:
                in_turn = False
                turns.append([st, i, spk])
            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
                turns.append([st, i, spk])
                st, in_turn = i, True
                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
    if in_turn:
        turns.append([st, len(label), spk])
    return turns


def process(task_args):
    task_id, task_list, _, args = task_args
    spk_count = [0, 0]
    for mid, wav_path, rttms in task_list:
        wav, sr = sf.read(wav_path, dtype="int16")
        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
        multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)
        extracted_spk = []
        count = 1
        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
            if (ed - st) >= args.min_dur * args.sr:
                seg = wav[st: ed]
                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
                count += 1
                if spk not in extracted_spk:
                    extracted_spk.append(spk)
        if len(extracted_spk) != len(spk_list):
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
            ))
        spk_count[0] += len(extracted_spk)
        spk_count[1] += len(spk_list)
    return spk_count


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
@@ -1,73 +0,0 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import argparse


class MyRunner(MultiProcessRunnerV3):

    def prepare(self, parser):
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--n_spk", type=int, default=4)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        args = parser.parse_args()

        meeting_scp = load_scp_as_dict(os.path.join(args.dir, "meeting.scp"))
        rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        task_list = [(mid, meeting_scp[mid], rttm_path) for mid, rttm_path in rttm_scp]
        return task_list, None, args

    def post(self, results_list, args):
        pass


def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000):
    multi_label = np.zeros([n_spk, length], dtype=int)
    for _, st, dur, spk in spk_turns:
        st, dur = int(st * sr), int(dur * sr)
        idx = spk_list.index(spk)
        multi_label[idx, st:st+dur] = 1
    if not remove_sil:
        return multi_label.T

    speech_count = np.sum(multi_label, axis=0)
    idx = np.nonzero(speech_count)[0]
    label = multi_label[:, idx].T
    return label  # (T, N)


def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000):
    wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
    spk_turns = []
    spk_list = []
    for one_line in open(rttm_path, "rt"):
        parts = one_line.strip().split(" ")
        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
        spk = "{}_S{:03d}".format(mid, spk)
        if spk not in spk_list:
            spk_list.append(spk)
        spk_turns.append((mid, st, dur, spk))
    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil)
    return labels


def process(task_args):
    _, task_list, _, args = task_args
    for mid, wav_path, rttm_path in task_list:
        meeting_labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil)
        save_path = os.path.join(args.out_dir, "{}.lbl".format(mid))
        np.save(save_path, meeting_labels.astype(bool))
        print(mid)
    return None


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
@@ -1,53 +0,0 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import soundfile as sf
from tqdm import tqdm
import argparse


class MyRunner(MultiProcessRunnerV3):

    def prepare(self, parser):
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--chunk_dur", type=float, default=16)
        parser.add_argument("--shift_dur", type=float, default=4)
        args = parser.parse_args()

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        wav_scp = load_scp_as_list(args.wav_scp)
        return wav_scp, None, args

    def post(self, results_list, args):
        pass


def process(task_args):
    _, task_list, _, args = task_args
    chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr)
    for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
        if not os.path.exists(os.path.join(args.out_dir, mid)):
            os.makedirs(os.path.join(args.out_dir, mid))

        wav = librosa.load(wav_path, args.sr, True)[0] * 32767
        n_chunk = (len(wav) - chunk_len) // shift_len + 1
        if (len(wav) - chunk_len) % shift_len > 0:
            n_chunk += 1
        for i in range(n_chunk):
            seg = wav[i*shift_len: i*shift_len + chunk_len]
            st = int(float(i*shift_len)/args.sr * 100)
            dur = int(float(len(seg))/args.sr * 100)
            file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur)
            save_path = os.path.join(args.out_dir, mid, file_name)
            sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
    return None


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
@@ -1,57 +0,0 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse


class MyRunner(MultiProcessRunnerV3):

    def prepare(self, parser):
        parser.add_argument("--rttm_scp", type=str)
        parser.add_argument("--seg_file", type=str)
        args = parser.parse_args()

        if not os.path.exists(os.path.dirname(args.seg_file)):
            os.makedirs(os.path.dirname(args.seg_file))

        task_list = load_scp_as_list(args.rttm_scp)
        return task_list, None, args

    def post(self, results_list, args):
        with open(args.seg_file, "wt", encoding="utf-8") as fd:
            for results in results_list:
                fd.writelines(results)


def process(task_args):
    _, task_list, _, args = task_args
    outputs = []
    for mid, rttm_path in task_list:
        spk_turns = []
        length = 0
        for one_line in open(rttm_path, 'rt', encoding="utf-8"):
            parts = one_line.strip().split(" ")
            _, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            st, ed = int(st*100), int((st + dur)*100)
            length = ed if ed > length else length
            spk_turns.append([mid, st, ed, spk_name])
        is_sph = np.zeros((length+1, ), dtype=bool)
        for _, st, ed, _ in spk_turns:
            is_sph[st:ed] = True

        st, in_speech = 0, False
        for i in range(length+1):
            if not in_speech and is_sph[i]:
                st, in_speech = i, True
            if in_speech and not is_sph[i]:
                in_speech = False
                outputs.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
                    mid, st, i, mid, float(st)/100, float(i)/100
                ))
    return outputs


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
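Several of the Python scripts in this diff parse RTTM files line by line, always taking field 1 (recording id), field 3 (onset in seconds), field 4 (duration in seconds) and field 7 (speaker label), as in the sample line `SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>` quoted in their comments. A minimal sketch of that shared parsing assumption (the helper name parse_rttm is hypothetical, not a function from the repo):

# Sketch of the RTTM field layout assumed throughout these scripts:
# parts[1]=recording id, parts[3]=onset (s), parts[4]=duration (s), parts[7]=speaker label.
def parse_rttm(path):
    turns = []
    for line in open(path, "rt", encoding="utf-8"):
        parts = [x for x in line.strip().split(" ") if x != ""]
        if not parts or parts[0] != "SPEAKER":
            continue
        rec_id, onset, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
        turns.append((rec_id, onset, dur, spk))
    return turns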
@@ -1,138 +0,0 @@
import soundfile
import kaldiio
from tqdm import tqdm
import json
import os
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import numpy as np
import argparse
import random

short_spk_list = []
def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
    all_utts = spk2utt[spk]
    idx_list = list(range(len(all_utts)))
    random.shuffle(idx_list)
    count = 0
    utt_list = []
    for i in idx_list:
        utt_id = all_utts[i]
        utt_list.append(utt_id)
        count += int(utt2frames[utt_id])
        if count >= total_len:
            break
    if count < 300 and spk not in short_spk_list:
        print("Speaker {} has only {} frames, but at least {} frames are expected; using them all.".format(spk, count, 300))
        short_spk_list.append(spk)

    ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list]
    ivc_list = [x/np.linalg.norm(x, axis=-1) for x in ivc_list]
    ivc = np.concatenate(ivc_list, axis=0)
    ivc = np.mean(ivc, axis=0, keepdims=False)
    return ivc


def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
    out_prefix = args.out

    ivc_dim = 192
    win_len, win_shift = 400, 160
    label_weights = 2 ** np.array(list(range(args.n_spk)))
    wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
    ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
    label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))

    frames_list = []
    chunk_size = int(args.chunk_size * args.sr)
    chunk_shift = int(args.chunk_shift * args.sr)
    for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
        meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32')
        num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
        meeting_labels = np.load(labels_scp[mid])
        for i in range(num_chunk):
            st, ed = i*chunk_shift, i*chunk_shift+chunk_size
            seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100))
            wav_writer(seg_id, meeting_wav[st: ed])

            xvec_list = []
            for spk in meeting2spk_list[mid]:
                spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)
                xvec_list.append(spk_xvec)
            for _ in range(args.n_spk - len(xvec_list)):
                xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
            xvec = np.row_stack(xvec_list)
            ivc_writer(seg_id, xvec)

            wav_label = meeting_labels[st:ed, :]
            frame_num = (ed-st) // win_shift
            # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant")
            feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
            for i in range(frame_num):
                frame_label = wav_label[i*win_shift: (i+1)*win_shift, :]
                feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32)
            label_writer(seg_id, feat_label)

            frames_list.append((mid, feat_label.shape[0]))
    return frames_list


def calc_spk_list(rttm_path):
    spk_list = []
    for one_line in open(rttm_path, "rt"):
        parts = one_line.strip().split(" ")
        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
        spk = "{}_S{:03d}".format(mid, spk)
        if spk not in spk_list:
            spk_list.append(spk)

    return spk_list


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumped files.")
    parser.add_argument("--n_spk", type=int, default=4)
    parser.add_argument("--use_lfr", default=False, action="store_true")
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=16000)
    parser.add_argument("--chunk_size", type=int, default=16)
    parser.add_argument("--chunk_shift", type=int, default=4)
    args = parser.parse_args()

    if not os.path.exists(os.path.dirname(args.out)):
        os.makedirs(os.path.dirname(args.out))

    meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
    labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
    rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
    utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
    utt2frames = {}
    for uttid, wav_path in utt2wav.items():
        wav, sr = soundfile.read(wav_path, dtype="int16")
        utt2frames[uttid] = int(len(wav) / sr * 100)

    meeting2spk_list = {}
    for mid, rttm_path in rttm_scp:
        meeting2spk_list[mid] = calc_spk_list(rttm_path)

    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
            if spk not in spk2utt:
                spk2utt[spk] = []
            spk2utt[spk].append(utt)

    # random.shuffle(feat_scp)
    meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
    total_frames = sum([x[1] for x in meeting_lens])
    print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))


if __name__ == '__main__':
    main()
@@ -1,110 +0,0 @@
from __future__ import print_function
import numpy as np
import os
import sys
import argparse
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import librosa
import soundfile as sf
from copy import deepcopy
import json
from tqdm import tqdm


class MyRunner(MultiProcessRunnerV3):
    def prepare(self, parser):
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("rttm_scp", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        wav_scp = load_scp_as_list(args.wav_scp)
        rttm_scp = load_scp_as_dict(args.rttm_scp)
        task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp]
        return task_list, None, args

    def post(self, result_list, args):
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("Found {} speakers, extracted {}.".format(count[1], count[0]))


# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4):
    labels = np.zeros([max_spk_num, length], int)
    spk_list = []
    for one_line in open(rttm_path, 'rt'):
        parts = one_line.strip().split(" ")
        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
        if spk_name.isdigit():
            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
        if spk_name not in spk_list:
            spk_list.append(spk_name)
        st, dur = int(st*sr), int(dur*sr)
        idx = spk_list.index(spk_name)
        labels[idx, st:st+dur] = 1
    return labels, spk_list


def get_nonoverlap_turns(multi_label, spk_list):
    turns = []
    label = np.sum(multi_label, axis=0) == 1
    spk, in_turn, st = None, False, 0
    for i in range(len(label)):
        if not in_turn and label[i]:
            st, in_turn = i, True
            spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
        if in_turn:
            if not label[i]:
                in_turn = False
                turns.append([st, i, spk])
            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
                turns.append([st, i, spk])
                st, in_turn = i, True
                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
    if in_turn:
        turns.append([st, len(label), spk])
    return turns


def process(task_args):
    task_id, task_list, _, args = task_args
    spk_count = [0, 0]
    for mid, wav_path, rttm_path in task_list:
        wav, sr = sf.read(wav_path, dtype="int16")
        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
        multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)
        extracted_spk = []
        count = 1
        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
            if (ed - st) >= args.min_dur * args.sr:
                seg = wav[st: ed]
                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
                count += 1
                if spk not in extracted_spk:
                    extracted_spk.append(spk)
        if len(extracted_spk) != len(spk_list):
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
            ))
        spk_count[0] += len(extracted_spk)
        spk_count[1] += len(spk_list)
    return spk_count


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
@@ -1,60 +0,0 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import soundfile as sf
import argparse


class MyRunner(MultiProcessRunnerV3):

    def prepare(self, parser):
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        args = parser.parse_args()

        meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp"))
        vad_file = open(os.path.join(args.dir, "segments"), encoding="utf-8")
        meeting2vad = {}
        for one_line in vad_file:
            uid, mid, st, ed = one_line.strip().split(" ")
            st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
            if mid not in meeting2vad:
                meeting2vad[mid] = []
            meeting2vad[mid].append((uid, st, ed))

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp]
        return task_list, None, args

    def post(self, results_list, args):
        pass


def process(task_args):
    _, task_list, _, args = task_args
    for mid, wav_path, vad_list in task_list:
        wav = librosa.load(wav_path, args.sr, True)[0] * 32767
        seg_list = []
        pos_map = []
        offset = 0
        for uid, st, ed in vad_list:
            seg_list.append(wav[st: ed])
            pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset+ed-st))
            offset = offset + ed - st
        out = np.concatenate(seg_list, axis=0)
        save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
        sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
        map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
        with open(map_path, "wt", encoding="utf-8") as fd:
            fd.writelines(pos_map)
        print(mid)
    return None


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
@@ -1,261 +0,0 @@
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict
import random
from typing import List, Dict
from copy import deepcopy
import json
logging.basicConfig(
    level="INFO",
    format=f"[{os.uname()[1].split('.')[0]}]"
           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


class MyRunner(MultiProcessRunnerV3):

    def prepare(self, parser: argparse.ArgumentParser):
        parser.add_argument("--label_scp", type=str, required=True)
        parser.add_argument("--wav_scp", type=str, required=True)
        parser.add_argument("--utt2spk", type=str, required=True)
        parser.add_argument("--spk2meeting", type=str, required=True)
        parser.add_argument("--utt2xvec", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--chunk_size", type=float, default=16)
        parser.add_argument("--chunk_shift", type=float, default=4)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        parser.add_argument("--embedding_dim", type=int, default=None)
        parser.add_argument("--average_emb_num", type=int, default=0)
        parser.add_argument("--subset", type=int, default=0)
        parser.add_argument("--data_json", type=str, default=None)
        parser.add_argument("--seed", type=int, default=1234)
        parser.add_argument("--log_interval", type=int, default=100)
        args = parser.parse_args()
        random.seed(args.seed)
        np.random.seed(args.seed)

        logging.info("Loading data...")
        if not os.path.exists(args.data_json):
            label_list = load_scp_as_list(args.label_scp)
            wav_scp = load_scp_as_dict(args.wav_scp)
            utt2spk = load_scp_as_dict(args.utt2spk)
            utt2xvec = load_scp_as_dict(args.utt2xvec)
            spk2meeting = load_scp_as_dict(args.spk2meeting)

            meeting2spks = OrderedDict()
            for spk, meeting in spk2meeting.items():
                if meeting not in meeting2spks:
                    meeting2spks[meeting] = []
                meeting2spks[meeting].append(spk)

            spk2utts = OrderedDict()
            for utt, spk in utt2spk.items():
                if spk not in spk2utts:
                    spk2utts[spk] = []
                spk2utts[spk].append(utt)

            os.makedirs(os.path.dirname(args.data_json), exist_ok=True)
            logging.info("Dump data...")
            json.dump({
                "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
                "spk2utts": spk2utts, "meeting2spks": meeting2spks
            }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
        else:
            data_dict = json.load(open(args.data_json, "rt", encoding="utf-8"))
            label_list = data_dict["label_list"]
            wav_scp = data_dict["wav_scp"]
            utt2xvec = data_dict["utt2xvec"]
            spk2utts = data_dict["spk2utts"]
            meeting2spks = data_dict["meeting2spks"]

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        args.chunk_size = int(args.chunk_size / args.frame_shift)
        args.chunk_shift = int(args.chunk_shift / args.frame_shift)

        if args.embedding_dim is None:
            args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
            logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))

        logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
            len(wav_scp), len(spk2utts), len(meeting2spks)
        ))
        return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args

    def post(self, results_list, args):
        logging.info("[main]: Got {} chunks.".format(sum(results_list)))


def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
    utt_list = spk2utts[spk]
    wav_list = []
    cur_length = 0
    while cur_length < sample_length:
        uttid = random.choice(utt_list)
        wav, fs = soundfile.read(wav_scp[uttid], dtype='float32')
        wav_list.append(wav)
        cur_length += len(wav)
    concat_wav = np.concatenate(wav_list, axis=0)
    start = random.randint(0, len(concat_wav) - sample_length)
    return concat_wav[start: start+sample_length]


def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
    # process for dummy speaker
    if spk == "None":
        return np.zeros((1, embedding_dim), dtype=np.float32)

    # calculate averaged speaker embeddings
    utt_list = spk2utts[spk]
    if average_emb_num == 0 or average_emb_num > len(utt_list):
        xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list]
    else:
        xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)]
    xvec = np.concatenate(xvec_list, axis=0)
    xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True)
    xvec = np.mean(xvec, axis=0)

    return xvec


def simu_chunk(
    frame_label: np.ndarray,
    sample_label: np.ndarray,
    wav_scp: Dict[str, str],
    utt2xvec: Dict[str, str],
    spk2utts: Dict[str, List[str]],
    meeting2spks: Dict[str, List[str]],
    all_speaker_list: List[str],
    meeting_list: List[str],
    embedding_dim: int,
    average_emb_num: int,
):
    frame_length, max_spk_num = frame_label.shape
    sample_length = sample_label.shape[0]
    positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0))
    pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])

    # get positive speakers
    if len(pos_speaker_list) >= positive_speaker_num:
        pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num)
    else:
        while len(pos_speaker_list) < positive_speaker_num:
            _spk = random.choice(all_speaker_list)
            if _spk not in pos_speaker_list:
                pos_speaker_list.append(_spk)

    # get negative speakers
    negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num)
    neg_speaker_list = []
    while len(neg_speaker_list) < negative_speaker_num:
        _spk = random.choice(all_speaker_list)
        if _spk not in pos_speaker_list and _spk not in neg_speaker_list:
            neg_speaker_list.append(_spk)
    neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num))

    random.shuffle(pos_speaker_list)
    random.shuffle(neg_speaker_list)
    seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
    this_spk_list = []
    for idx, frame_num in enumerate(frame_label.sum(axis=0)):
        if frame_num > 0:
            spk = pos_speaker_list.pop(0)
            this_spk_list.append(spk)
            simu_spk_wav = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
            seperated_wav[:, idx] = simu_spk_wav
        else:
            spk = neg_speaker_list.pop(0)
            this_spk_list.append(spk)

    # calculate mixed wav
    mixed_wav = np.sum(seperated_wav * sample_label, axis=1)

    # shuffle the order of speakers
    shuffle_idx = list(range(max_spk_num))
    random.shuffle(shuffle_idx)
    this_spk_list = [this_spk_list[x] for x in shuffle_idx]
    seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose()
    frame_label = frame_label.transpose()[shuffle_idx].transpose()

    # calculate profile
    profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
               for spk in this_spk_list]
    profile = np.vstack(profile)
    # pse_weights = 2 ** np.arange(max_spk_num)
    # pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1)
    # pse_label = pse_label.astype(str).tolist()

    return mixed_wav, seperated_wav, profile, frame_label


def process(task_args):
    task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
    logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))

    out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
    wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))

    # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
    # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))

    out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
    profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))

    out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1))
    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))

    speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())

    labels_list = []
    total_chunks = 0
    for org_mid, label_path in task_list:
        whole_label = kaldiio.load_mat(label_path)
        # random offset to keep diversity
        rand_shift = random.randint(0, args.chunk_shift)
        num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
        labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
        total_chunks += num_chunk

    idx = 0
    simu_chunk_count = 0
    for org_mid, whole_label, rand_shift, num_chunk in labels_list:
        for i in range(num_chunk):
            idx = idx + 1
            st = i * args.chunk_shift + rand_shift
            ed = i * args.chunk_shift + args.chunk_size + rand_shift
            utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
                args.subset + 1, task_idx + 1, org_mid, st, ed
            )
            frame_label = whole_label[st: ed, :]
            sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
            mix_wav, seg_wav, profile, frame_label = simu_chunk(
                frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
                speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
            )
            wav_mix_writer(utt_id, mix_wav)
            # wav_sep_writer(utt_id, seg_wav)
            profile_writer(utt_id, profile)
            label_writer(utt_id, frame_label)

            simu_chunk_count += 1
            if simu_chunk_count % args.log_interval == 0:
                logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
                    task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
    wav_mix_writer.close()
    # wav_sep_writer.close()
    profile_writer.close()
    label_writer.close()
    logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
    return simu_chunk_count


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()