Merge pull request #13 from alibaba-damo-academy/dev

Dev
This commit is contained in:
zhifu gao 2022-12-12 09:59:46 +08:00 committed by GitHub
commit 7817db2e20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 6 deletions

View File

@ -8,7 +8,7 @@ gpu_num=2
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=8
njob=1
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
@ -219,7 +219,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
fi
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--batch_size 100 \
--ngpu "${_ngpu}" \
--njob ${njob} \
--gpuid_list ${gpuid_list} \

View File

@ -8,7 +8,7 @@ gpu_num=2
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=8
njob=1
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
@ -235,7 +235,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
fi
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--batch_size 100 \
--ngpu "${_ngpu}" \
--njob ${njob} \
--gpuid_list ${gpuid_list} \

View File

@ -441,7 +441,7 @@ def inference(
"decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
format(length, forward_time, 100 * forward_time / (length*lfr_factor)))
for batch_id in range(len(results)):
for batch_id in range(_bs):
result = [results[batch_id][:-2]]
key = keys[batch_id]

View File

@ -31,10 +31,12 @@ class CifPredictor(nn.Module):
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
if mask is not None:
alphas = alphas * mask.transpose(-1, -2).float()
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length
elif target_label is not None: