This commit is contained in:
嘉渊 2023-07-19 22:34:52 +08:00
parent 81fe1e0a09
commit f5bd371837
4 changed files with 21 additions and 15 deletions

View File

@ -54,7 +54,7 @@ if __name__ == '__main__':
parser.add_argument(
"--sampling_rate",
type=int,
default=10,
default=8000,
help="sampling rate",
)
parser.add_argument(
@ -104,7 +104,7 @@ if __name__ == '__main__':
print("Start inference")
with open(args.output_rttm_file, "w") as wf:
for wav_id in wav_items.keys():
print("Process wav: {}\n".format(wav_id))
print("Process wav: {}".format(wav_id))
data, rate = sf.read(wav_items[wav_id])
speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
speech = eend_ola_feature.transform(speech)

View File

@ -245,13 +245,17 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
fi
## inference
#if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# echo "Inference"
# mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
# CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
# --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
# --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
# --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
# --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
#fi
# inference and compute DER
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Inference"
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
--config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
--model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
--output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
--wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
md-eval.pl -c 0.25 \
-r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \
-s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
fi

View File

@ -245,7 +245,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
fi
# inference
# inference and compute DER
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Inference"
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
@ -255,4 +255,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
--output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
--wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
md-eval.pl -c 0.25 \
-r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \
-s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
fi

View File

@ -157,12 +157,11 @@ class DiarEENDOLAModel(FunASRModel):
def estimate_sequential(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
n_speakers: int = None,
shuffle: bool = True,
threshold: float = 0.5,
**kwargs):
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
emb = self.forward_encoder(speech, speech_lengths)
if shuffle:
orders = [np.arange(e.shape[0]) for e in emb]