mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
81fe1e0a09
commit
f5bd371837
@ -54,7 +54,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sampling_rate",
|
"--sampling_rate",
|
||||||
type=int,
|
type=int,
|
||||||
default=10,
|
default=8000,
|
||||||
help="sampling rate",
|
help="sampling rate",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -104,7 +104,7 @@ if __name__ == '__main__':
|
|||||||
print("Start inference")
|
print("Start inference")
|
||||||
with open(args.output_rttm_file, "w") as wf:
|
with open(args.output_rttm_file, "w") as wf:
|
||||||
for wav_id in wav_items.keys():
|
for wav_id in wav_items.keys():
|
||||||
print("Process wav: {}\n".format(wav_id))
|
print("Process wav: {}".format(wav_id))
|
||||||
data, rate = sf.read(wav_items[wav_id])
|
data, rate = sf.read(wav_items[wav_id])
|
||||||
speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
|
speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
|
||||||
speech = eend_ola_feature.transform(speech)
|
speech = eend_ola_feature.transform(speech)
|
||||||
|
|||||||
@ -245,13 +245,17 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
|||||||
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
|
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
|
||||||
fi
|
fi
|
||||||
|
|
||||||
## inference
|
# inference and compute DER
|
||||||
#if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
# echo "Inference"
|
echo "Inference"
|
||||||
# mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
|
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
|
||||||
# CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
|
CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
|
||||||
# --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
|
--config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
|
||||||
# --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
|
--model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
|
||||||
# --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
|
--output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
|
||||||
# --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
|
--wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
|
||||||
#fi
|
1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
|
||||||
|
md-eval.pl -c 0.25 \
|
||||||
|
-r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \
|
||||||
|
-s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
|
||||||
|
fi
|
||||||
@ -245,7 +245,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
|||||||
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
|
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# inference
|
# inference and compute DER
|
||||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
echo "Inference"
|
echo "Inference"
|
||||||
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
|
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
|
||||||
@ -255,4 +255,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|||||||
--output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
|
--output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
|
||||||
--wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
|
--wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
|
||||||
1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
|
1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
|
||||||
|
md-eval.pl -c 0.25 \
|
||||||
|
-r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \
|
||||||
|
-s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
|
||||||
fi
|
fi
|
||||||
@ -157,12 +157,11 @@ class DiarEENDOLAModel(FunASRModel):
|
|||||||
|
|
||||||
def estimate_sequential(self,
|
def estimate_sequential(self,
|
||||||
speech: torch.Tensor,
|
speech: torch.Tensor,
|
||||||
speech_lengths: torch.Tensor,
|
|
||||||
n_speakers: int = None,
|
n_speakers: int = None,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
threshold: float = 0.5,
|
threshold: float = 0.5,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
|
speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
|
||||||
emb = self.forward_encoder(speech, speech_lengths)
|
emb = self.forward_encoder(speech, speech_lengths)
|
||||||
if shuffle:
|
if shuffle:
|
||||||
orders = [np.arange(e.shape[0]) for e in emb]
|
orders = [np.arange(e.shape[0]) for e in emb]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user