sond pipeline

志浩 2023-02-24 11:50:42 +08:00
parent 64bd74c7be
commit 0a6ff596c6
7 changed files with 351 additions and 6 deletions


@@ -0,0 +1,121 @@
model: sond
model_conf:
    lsm_weight: 0.0
    length_normalized_loss: true
    max_spk_num: 16

# speech encoder
encoder: ecapa_tdnn
encoder_conf:
    # passed in by the model; equal to the feature dim
    # input_size: 80
    pool_size: 20
    stride: 1

speaker_encoder: conv
speaker_encoder_conf:
    input_units: 256
    num_layers: 3
    num_units: 256
    kernel_size: 1
    dropout_rate: 0.0
    position_encoder: null
    out_units: 256
    out_norm: false
    auxiliary_states: false
    tf2torch_tensor_name_prefix_torch: speaker_encoder
    tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder

ci_scorer: dot
ci_scorer_conf: {}

cd_scorer: san
cd_scorer_conf:
    input_size: 512
    output_size: 512
    out_units: 1
    attention_heads: 4
    linear_units: 1024
    num_blocks: 4
    dropout_rate: 0.0
    positional_dropout_rate: 0.0
    attention_dropout_rate: 0.0
    # use the string "null" to remove the input layer
    input_layer: "null"
    pos_enc_class: null
    normalize_before: true
    tf2torch_tensor_name_prefix_torch: cd_scorer
    tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer

# post net
decoder: fsmn
decoder_conf:
    in_units: 32
    out_units: 2517
    filter_size: 31
    fsmn_num_layers: 6
    dnn_num_layers: 1
    num_memory_units: 512
    ffn_inner_dim: 512
    dropout_rate: 0.0
    tf2torch_tensor_name_prefix_torch: decoder
    tf2torch_tensor_name_prefix_tf: EAND/post_net

frontend: wav_frontend
frontend_conf:
    fs: 16000
    window: povey
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    filter_length_min: -1
    filter_length_max: -1
    lfr_m: 1
    lfr_n: 1
    dither: 0.0
    snip_edges: false

# minibatch related
batch_type: length
# 16 s * 16 kHz * 16 utterances = 4,096,000 samples
batch_bins: 4096000
num_workers: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 50
val_scheduler_criterion:
    - valid
    - acc
best_model_criterion:
-   - valid
    - der
    - min
-   - valid
    - forward_steps
    - max
keep_nbest_models: 10

optim: adam
optim_conf:
    lr: 0.001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 10000

# spec aug disabled; conf kept for reference
specaug: null
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2

log_interval: 50
# no input normalization
normalize: null
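A note on two YAML details above. With batch_type "length", batch_bins counts raw waveform samples, so the minibatch comment works out exactly; and normalize must be null (an unquoted None would parse as the string "None", not a null). A minimal self-contained PyYAML check, not part of the commit:

import yaml

# "null" parses to Python None; an unquoted "None" would parse as the
# string "None" and silently skip the intended null branch downstream.
snippet = yaml.safe_load("batch_bins: 4096000\nnormalize: null\n")
assert snippet["normalize"] is None

# The minibatch comment checks out: 16 s * 16,000 Hz * 16 utterances.
assert snippet["batch_bins"] == 16 * 16000 * 16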

egs/mars/sd/local_run.sh Executable file

@@ -0,0 +1,171 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="6,7"
gpu_num=2
count=1
gpu_inference=true    # whether to perform GPU decoding; set false for CPU decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="."    # feature output directory
exp_dir="."
lang=zh
dumpdir=dump/raw
feats_type=raw
token_type=char
scp=wav.scp
type=kaldi_ark
stage=3
stop_stage=4
# feature configuration
feats_dim=
sample_frequency=16000
nj=32
speed_perturb=
# exp tag
tag="exp1"
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode: it will exit on
# -e 'error', -u 'undefined variable', and -o pipefail 'error in pipeline'
set -e
set -u
set -o pipefail
train_set=train
valid_set=dev
test_sets="dev test"
asr_config=conf/train_asr_conformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pth
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES    # GPUs for decoding; the same as the training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
# Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Training"
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ]; then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
            asr_train.py \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --input_size $feats_dim \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi
# Testing Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} already exists. If you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/${dumpdir}/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 > "${_dir}/${f}"
            fi
        done
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
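A quick worked example of the decode fan-out above: with the script defaults (gpu_inference=true, CUDA_VISIBLE_DEVICES="6,7" so ngpu=2, njob=5), inference_nj resolves to 10 and is then capped by the number of scp entries. The same arithmetic in Python (values assumed from the defaults, purely illustrative):

# Mirrors the inference_nj/_nj computation in local_run.sh.
gpu_inference = True
ngpu = len("6,7".split(","))            # from CUDA_VISIBLE_DEVICES -> 2
njob = 5

inference_nj = ngpu * njob if gpu_inference else njob   # -> 10
num_scp_file = 8                        # e.g. a tiny dev set of 8 utterances
nj = min(inference_nj, num_scp_file)    # -> 8 parallel decode jobs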

egs/mars/sd/path.sh Executable file

@@ -0,0 +1,5 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
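This PATH export is what lets stage 3 of local_run.sh invoke asr_train.py with no path prefix. A quick way to verify the entry point resolves after sourcing path.sh (output path hypothetical):

import shutil

# Expect something like .../FunASR/funasr/bin/asr_train.py once path.sh
# is sourced; None means the PATH export did not take effect in this shell.
print(shutil.which("asr_train.py"))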


@@ -0,0 +1,45 @@
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict


class MyRunner(MultiProcessRunnerV3):
    def prepare(self, parser: argparse.ArgumentParser):
        parser.add_argument("--input_scp", type=str, required=True)
        parser.add_argument("--out_path")
        args = parser.parse_args()

        if not os.path.exists(os.path.dirname(args.out_path)):
            os.makedirs(os.path.dirname(args.out_path))
        task_list = load_scp_as_list(args.input_scp)
        return task_list, None, args

    def post(self, result_list, args):
        fd = open(args.out_path, "wt", encoding="utf-8")
        for results in result_list:
            for uttid, shape in results:
                fd.write("{} {}\n".format(uttid, ",".join(shape)))
        fd.close()


def process(task_args):
    task_idx, task_list, _, args = task_args
    rst = []
    for uttid, file_path in task_list:
        data = kaldiio.load_mat(file_path)
        shape = [str(x) for x in data.shape]
        rst.append((uttid, shape))
    return rst


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
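The per-task contract of process() is small: one Kaldi-style scp line in, one "uttid dims" shape line out. A self-contained sketch of that round trip (kaldiio.load_mat returns an ndarray, e.g. (frames, feat_dim) for feature matrices; the example values are illustrative):

import kaldiio

def shape_line(scp_line: str) -> str:
    """One 'uttid rxspecifier' scp line -> one 'uttid frames,dim' shape line."""
    uttid, rxfile = scp_line.strip().split(maxsplit=1)
    mat = kaldiio.load_mat(rxfile)    # e.g. ndarray of shape (998, 80)
    return "{} {}".format(uttid, ",".join(str(d) for d in mat.shape))

# shape_line("utt1 feats.ark:42") -> "utt1 998,80"  (illustrative values)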


@@ -90,6 +90,7 @@ class DiarSondModel(AbsESPnetModel):
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :])
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0

    def generate_pse_embedding(self):
        embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=float)
@@ -123,7 +124,7 @@ class DiarSondModel(AbsESPnetModel):
        """
        assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
        batch_size = speech.shape[0]
        self.forward_steps = self.forward_steps + 1
        # 1. Network forward
        pred, inter_outputs = self.prediction_forward(
            speech, speech_lengths,
@@ -198,6 +199,7 @@ class DiarSondModel(AbsESPnetModel):
            cf=cf,
            acc=acc,
            der=der,
            forward_steps=self.forward_steps,
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
@@ -262,8 +264,10 @@ class DiarSondModel(AbsESPnetModel):
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}
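The forward_steps counter added in this file pairs with the best_model_criterion entry (valid, forward_steps, max) in the config above: since the counter only grows, taking the max effectively keeps the most recent validated checkpoints alongside the min-DER ones. A trimmed sketch of the pattern (everything but the counter elided):

import torch

class CounterOnly(torch.nn.Module):
    # Illustration of the forward_steps bookkeeping in DiarSondModel.

    def __init__(self):
        super().__init__()
        self.forward_steps = 0              # plain int; grows monotonically

    def forward(self):
        self.forward_steps += 1             # one increment per training step
        stats = dict(forward_steps=self.forward_steps)  # reported with acc/der
        return stats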


@@ -528,8 +528,6 @@ class ECAPA_TDNN(torch.nn.Module):
    Arguments
    ---------
    device : str
        Device used, e.g., "cpu" or "cuda".
    activation : torch class
        A class for constructing the activation layers.
    channels : list of ints
@@ -555,7 +553,6 @@ class ECAPA_TDNN(torch.nn.Module):
    def __init__(
        self,
        input_size,
        device="cpu",
        lin_neurons=192,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
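Both hunks in this file drop the constructor's device argument (the hunk counts, -528,8 +528,6 and -555,7 +553,6, show two docstring lines and one parameter removed). If that reading is right, device placement now follows the usual PyTorch convention of moving the constructed module explicitly, something like:

from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN

# Assumed usage after this change: construct first, then place.
model = ECAPA_TDNN(input_size=80, lin_neurons=192)
model = model.to("cuda:0")      # replaces ECAPA_TDNN(..., device="cuda:0")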


@@ -24,6 +24,7 @@ from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
@@ -123,6 +124,7 @@ encoder_choices = ClassChoices(
        resnet34=ResNet34Diar,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        ecapa_tdnn=ECAPA_TDNN,
    ),
    type_check=AbsEncoder,
    default="resnet34",
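For context, this registry is what turns the YAML's encoder: ecapa_tdnn string into the class imported above, so the key must match the config spelling exactly. A sketch of the lookup inside this module (ESPnet-style ClassChoices API; the exact build call site is assumed, not shown in this diff):

# encoder_conf here stands for the encoder_conf mapping parsed from the YAML.
encoder_class = encoder_choices.get_class("ecapa_tdnn")   # -> ECAPA_TDNN
encoder = encoder_class(input_size=80, **encoder_conf)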