mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
minmo
This commit is contained in:
parent
b878ecad95
commit
26b270f89b
@ -0,0 +1,87 @@
|
||||
model: MinMo_S2T
|
||||
model_conf:
|
||||
lsm_weight: 0.1
|
||||
length_normalized_loss: true
|
||||
audio_encoder: /cpfs_speech/zhifu.gzf/init_model/SenseVoiceSANM
|
||||
audio_encoder_conf:
|
||||
hub: ms
|
||||
freeze: false
|
||||
freeze_layer_num: -1
|
||||
llm: Qwen2.5-14B-Instruct
|
||||
llm_conf:
|
||||
hub: hf
|
||||
freeze: true
|
||||
llm_dtype: bf16
|
||||
init_param_path: /cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2.5-14B-Instruct
|
||||
use_lora: false
|
||||
lora_conf:
|
||||
task_type: "CAUSAL_LM"
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
bias: "none"
|
||||
target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
audio_adaptor: Transformer
|
||||
audio_adaptor_conf:
|
||||
freeze: false
|
||||
downsample_rate: 2
|
||||
ffn_dim: 2048
|
||||
llm_dim: 5120
|
||||
encoder_dim: 1280
|
||||
n_layer: 1
|
||||
frontend: WhisperFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
n_mels: 128
|
||||
do_pad_trim: false
|
||||
filters_path: /cpfs_speech/zhifu.gzf/init_model/SenseVoiceSANM/assets/mel_filters.npz
|
||||
train_conf:
|
||||
use_lora: ${llm_conf.use_lora}
|
||||
accum_grad: 8
|
||||
grad_clip: 5
|
||||
max_epoch: 2
|
||||
keep_nbest_models: 100
|
||||
log_interval: 50
|
||||
effective_save_name_excludes:
|
||||
- llm.
|
||||
resume: true
|
||||
validate_interval: 10000
|
||||
save_checkpoint_interval: 10000
|
||||
avg_nbest_model: 100
|
||||
use_bf16: false
|
||||
use_deepspeed: false
|
||||
deepspeed_config: /nfs/zhifu.gzf/codebase/FunASR/examples/deepspeed_conf/ds_stage1.json
|
||||
save_init_model: false
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 0.00003
|
||||
weight_decay: 0.0
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 10000
|
||||
dataset: OpenAIDatasetMultiTurn
|
||||
dataset_conf:
|
||||
index_ds: OpenAIIndexDSJsonl
|
||||
batch_sampler: BatchSampler
|
||||
batch_type: token
|
||||
batch_size: 1500
|
||||
max_token_length: 1500
|
||||
shuffle: true
|
||||
sort_size: 512
|
||||
batch_size_scale_ratio_max: 2
|
||||
num_workers: 4
|
||||
audio_adaptor_downsample_rate: ${audio_adaptor_conf.downsample_rate}
|
||||
audio_encoder_downsample_rate: 4
|
||||
data_split_num: 512
|
||||
batch_size_sample_max: 15
|
||||
retry: 50
|
||||
batch_size_token_max: 4000
|
||||
max_source_length: 5500
|
||||
tokenizer: HuggingfaceTokenizer
|
||||
tokenizer_conf:
|
||||
init_param_path: ${llm_conf.init_param_path}
|
||||
enable_tf32: true
|
||||
debug: false
|
||||
excludes: llm.
|
||||
@ -0,0 +1,93 @@
|
||||
model: MinMo_S2T
|
||||
model_conf:
|
||||
lsm_weight: 0.1
|
||||
length_normalized_loss: true
|
||||
|
||||
|
||||
audio_encoder: /cpfs_speech/zhifu.gzf/init_model/SenseVoiceSANM
|
||||
audio_encoder_conf:
|
||||
hub: ms
|
||||
freeze: true
|
||||
freeze_layer_num: -1
|
||||
|
||||
|
||||
llm: Qwen2.5-7B-Instruct
|
||||
llm_conf:
|
||||
hub: hf
|
||||
freeze: true
|
||||
llm_dtype: bf16
|
||||
init_param_path: /cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2.5-7B-Instruct
|
||||
use_lora: false
|
||||
lora_conf:
|
||||
task_type: "CAUSAL_LM"
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
bias: "none"
|
||||
target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
|
||||
audio_adaptor: Transformer
|
||||
audio_adaptor_conf:
|
||||
freeze: false
|
||||
downsample_rate: 2
|
||||
ffn_dim: 2048
|
||||
encoder_dim: 1280
|
||||
n_layer: 2
|
||||
|
||||
frontend: WhisperFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
n_mels: 128
|
||||
do_pad_trim: false
|
||||
filters_path: /cpfs_speech/zhifu.gzf/init_model/SenseVoiceSANM/assets/mel_filters.npz
|
||||
train_conf:
|
||||
use_lora: ${llm_conf.use_lora}
|
||||
accum_grad: 8
|
||||
grad_clip: 5
|
||||
max_epoch: 2
|
||||
keep_nbest_models: 100
|
||||
log_interval: 50
|
||||
effective_save_name_excludes:
|
||||
- llm.
|
||||
resume: true
|
||||
validate_interval: 10000
|
||||
save_checkpoint_interval: 10000
|
||||
avg_nbest_model: 100
|
||||
use_bf16: false
|
||||
use_deepspeed: false
|
||||
deepspeed_config: /nfs/zhifu.gzf/codebase/FunASR/examples/deepspeed_conf/ds_stage1.json
|
||||
save_init_model: false
|
||||
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 0.0001
|
||||
weight_decay: 0.0
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 2000
|
||||
dataset: OpenAIDatasetMultiTurn
|
||||
dataset_conf:
|
||||
index_ds: OpenAIIndexDSJsonl
|
||||
batch_sampler: BatchSampler
|
||||
batch_type: token
|
||||
batch_size: 1500
|
||||
max_token_length: 1500
|
||||
shuffle: true
|
||||
sort_size: 512
|
||||
batch_size_scale_ratio_max: 2
|
||||
num_workers: 4
|
||||
audio_adaptor_downsample_rate: ${audio_adaptor_conf.downsample_rate}
|
||||
audio_encoder_downsample_rate: 4
|
||||
data_split_num: 512
|
||||
batch_size_sample_max: 15
|
||||
retry: 50
|
||||
batch_size_token_max: 4000
|
||||
max_source_length: 5500
|
||||
tokenizer: HuggingfaceTokenizer
|
||||
tokenizer_conf:
|
||||
init_param_path: ${llm_conf.init_param_path}
|
||||
enable_tf32: true
|
||||
debug: false
|
||||
excludes: llm.
|
||||
@ -0,0 +1,87 @@
|
||||
model: MinMo_S2T
|
||||
model_conf:
|
||||
lsm_weight: 0.1
|
||||
length_normalized_loss: true
|
||||
audio_encoder: /cpfs_speech/zhifu.gzf/init_model/SenseVoiceSANM
|
||||
audio_encoder_conf:
|
||||
hub: ms
|
||||
freeze: true
|
||||
freeze_layer_num: -1
|
||||
llm: Qwen2.5-14B-Instruct
|
||||
llm_conf:
|
||||
hub: hf
|
||||
freeze: true
|
||||
llm_dtype: bf16
|
||||
init_param_path: /cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2.5-14B-Instruct
|
||||
use_lora: true
|
||||
lora_conf:
|
||||
task_type: "CAUSAL_LM"
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
bias: "none"
|
||||
target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
audio_adaptor: Transformer
|
||||
audio_adaptor_conf:
|
||||
freeze: true
|
||||
downsample_rate: 2
|
||||
ffn_dim: 2048
|
||||
llm_dim: 5120
|
||||
encoder_dim: 1280
|
||||
n_layer: 1
|
||||
frontend: WhisperFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
n_mels: 128
|
||||
do_pad_trim: false
|
||||
filters_path: /cpfs_speech/zhifu.gzf/init_model/SenseVoiceSANM/assets/mel_filters.npz
|
||||
train_conf:
|
||||
use_lora: ${llm_conf.use_lora}
|
||||
accum_grad: 8
|
||||
grad_clip: 5
|
||||
max_epoch: 2
|
||||
keep_nbest_models: 100
|
||||
log_interval: 50
|
||||
effective_save_name_excludes:
|
||||
- llm.
|
||||
resume: true
|
||||
validate_interval: 10000
|
||||
save_checkpoint_interval: 10000
|
||||
avg_nbest_model: 100
|
||||
use_bf16: false
|
||||
use_deepspeed: false
|
||||
deepspeed_config: /nfs/zhifu.gzf/codebase/FunASR/examples/deepspeed_conf/ds_stage1.json
|
||||
save_init_model: false
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 0.00003
|
||||
weight_decay: 0.0
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 10000
|
||||
dataset: OpenAIDatasetMultiTurn
|
||||
dataset_conf:
|
||||
index_ds: OpenAIIndexDSJsonl
|
||||
batch_sampler: BatchSampler
|
||||
batch_type: token
|
||||
batch_size: 1500
|
||||
max_token_length: 1500
|
||||
shuffle: true
|
||||
sort_size: 512
|
||||
batch_size_scale_ratio_max: 2
|
||||
num_workers: 4
|
||||
audio_adaptor_downsample_rate: ${audio_adaptor_conf.downsample_rate}
|
||||
audio_encoder_downsample_rate: 4
|
||||
data_split_num: 512
|
||||
batch_size_sample_max: 15
|
||||
retry: 50
|
||||
batch_size_token_max: 4000
|
||||
max_source_length: 5500
|
||||
tokenizer: HuggingfaceTokenizer
|
||||
tokenizer_conf:
|
||||
init_param_path: ${llm_conf.init_param_path}
|
||||
enable_tf32: true
|
||||
debug: false
|
||||
excludes: llm.
|
||||
@ -0,0 +1,347 @@
|
||||
|
||||
##############################################
|
||||
# aishell_librispeech_wenetspeech_cv_fluers #
|
||||
##############################################
|
||||
|
||||
######
|
||||
ckpt_dir="/nfs/beinian.lzr/workspace/GPT-4o/Exp/Speech2Text_Align_8m-8gpu/Speech2Text_Align_V2p5_7b_1004"
|
||||
ckpt_id="ds-model.pt.ep0.640000"
|
||||
device="cuda:0"
|
||||
|
||||
stage=1
|
||||
stop_stage=8
|
||||
decode="true"
|
||||
|
||||
#data dir
|
||||
jsonl_dir="/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text_V2/TestData/ASR"
|
||||
|
||||
metrics_tool=../../../funasr/metrics/wer.py
|
||||
|
||||
out_dir="${ckpt_dir}/inference-${ckpt_id}"
|
||||
|
||||
######
|
||||
. utils/parse_options.sh || exit 1;
|
||||
|
||||
mkdir -p ${out_dir}
|
||||
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
|
||||
for data_set in "aishell1_test_speech2text.jsonl" "aishell2_ios_test_speech2text.jsonl"; do
|
||||
{
|
||||
jsonl=${jsonl_dir}/${data_set}
|
||||
output_dir=${out_dir}/${data_set}
|
||||
mkdir -p ${output_dir}
|
||||
pred_file=${output_dir}/1best_recog/text_tn
|
||||
ref_file=${output_dir}/1best_recog/label
|
||||
log_file=${output_dir}/log.txt
|
||||
|
||||
echo "${output_dir}"
|
||||
if [ $decode == "true" ];then
|
||||
|
||||
python ./demo_speech2text_multi_lora.py ${ckpt_dir} ${ckpt_id} ${jsonl} ${output_dir} ${device} &> ${log_file}
|
||||
|
||||
fi
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file} ++hyp_file=${pred_file} ++cer_file=${pred_file}.cer ++cn_postprocess=false
|
||||
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
cut ${pred_file} -d " " -f 1 > ${pred_file}.key
|
||||
cut ${pred_file} -d " " -f 2- > ${pred_file}.text
|
||||
|
||||
python utils/cn_tn.py ${pred_file}.text ${pred_file}.text.tn
|
||||
paste -d " " ${pred_file}.key ${pred_file}.text.tn > ${pred_file}.tn.proc
|
||||
|
||||
python utils/format5resV2.py ${ref_file} 1 > ${ref_file}.itn
|
||||
python utils/format5resV2.py ${pred_file}.tn.proc 1 > ${pred_file}.tn.proc.itn
|
||||
python ${metrics_tool} ++ref_file=${ref_file}.itn ++hyp_file=${pred_file}.tn.proc.itn ++cer_file=${pred_file}.tn.proc.itn.cer ++cn_postprocess=false
|
||||
|
||||
} &
|
||||
done
|
||||
wait
|
||||
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
|
||||
for data_set in "librispeech_test_clean_speech2text.jsonl" "librispeech_test_other_speech2text.jsonl"; do
|
||||
{
|
||||
jsonl=${jsonl_dir}/${data_set}
|
||||
output_dir=${out_dir}/${data_set}
|
||||
mkdir -p ${output_dir}
|
||||
pred_file=${output_dir}/1best_recog/text_tn
|
||||
ref_file=${output_dir}/1best_recog/label
|
||||
|
||||
log_file=${output_dir}/log.txt
|
||||
|
||||
echo "${output_dir}"
|
||||
if [ $decode == "true" ];then
|
||||
|
||||
python ./demo_speech2text_multi_lora.py ${ckpt_dir} ${ckpt_id} ${jsonl} ${output_dir} ${device} &> ${log_file}
|
||||
|
||||
fi
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file} ++hyp_file=${pred_file} ++cer_file=${pred_file}.cer ++cn_postprocess=false
|
||||
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
python utils/text_normalize/whisper_english_normalize.py ${pred_file} ${pred_file}.tn.proc
|
||||
python utils/text_normalize/whisper_english_normalize.py ${ref_file} ${ref_file}.tn.proc
|
||||
python ${metrics_tool} ++ref_file=${ref_file}.tn.proc ++hyp_file=${pred_file}.tn.proc ++cer_file=${pred_file}.tn.proc.cer ++cn_postprocess=false
|
||||
|
||||
}
|
||||
done
|
||||
# wait
|
||||
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
|
||||
for data_set in "wenetspeech_test_meeting_speech2text.jsonl" "wenetspeech_test_net_speech2text.jsonl"; do
|
||||
{
|
||||
jsonl=${jsonl_dir}/${data_set}
|
||||
output_dir=${out_dir}/${data_set}
|
||||
mkdir -p ${output_dir}
|
||||
pred_file=${output_dir}/1best_recog/text_tn
|
||||
ref_file=${output_dir}/1best_recog/label
|
||||
log_file=${output_dir}/log.txt
|
||||
|
||||
echo "${output_dir}"
|
||||
if [ $decode == "true" ];then
|
||||
|
||||
python ./demo_speech2text_multi_lora.py ${ckpt_dir} ${ckpt_id} ${jsonl} ${output_dir} ${device} &> ${log_file}
|
||||
|
||||
fi
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file} ++hyp_file=${pred_file} ++cer_file=${pred_file}.cer ++cn_postprocess=false
|
||||
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
cut ${pred_file} -d " " -f 1 > ${pred_file}.key
|
||||
cut ${pred_file} -d " " -f 2- > ${pred_file}.text
|
||||
|
||||
python utils/cn_tn.py ${pred_file}.text ${pred_file}.text.tn
|
||||
paste -d " " ${pred_file}.key ${pred_file}.text.tn > ${pred_file}.tn.proc
|
||||
|
||||
python utils/clean_res.py ${ref_file} ${ref_file}.tn.proc
|
||||
python utils/format5resV2.py ${ref_file}.tn.proc 1 > ${ref_file}.itn
|
||||
|
||||
python utils/format5resV2.py ${pred_file}.tn.proc 1 > ${pred_file}.tn.proc.itn
|
||||
python ${metrics_tool} ++ref_file=${ref_file}.itn ++hyp_file=${pred_file}.tn.proc.itn ++cer_file=${pred_file}.tn.proc.itn.cer ++cn_postprocess=false
|
||||
|
||||
}
|
||||
done
|
||||
# wait
|
||||
|
||||
fi
|
||||
|
||||
|
||||
jsonl_dir="/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData"
|
||||
|
||||
new_prompt="语音转写,不进行文本规整:"
|
||||
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
|
||||
for data_set in "common_voice_zh-CN_with_punc_itn_speech2text_singleprompt.jsonl" "fleurs_cmn_hans_cn_with_punc_itn_speech2text_singleprompt.jsonl"; do
|
||||
{
|
||||
jsonl=${jsonl_dir}/${data_set}
|
||||
output_dir=${out_dir}/${data_set}
|
||||
mkdir -p ${output_dir}
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
ref_file=${output_dir}/1best_recog/label
|
||||
|
||||
log_file=${output_dir}/log.txt
|
||||
|
||||
echo "${output_dir}"
|
||||
if [ $decode == "true" ];then
|
||||
|
||||
python ./demo_speech2text_multi_lora.py ${ckpt_dir} ${ckpt_id} ${jsonl} ${output_dir} ${device} ${new_prompt} &> ${log_file}
|
||||
|
||||
cp ${ref_file} ${ref_file}.ori
|
||||
|
||||
fi
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file} ++hyp_file=${pred_file} ++cer_file=${pred_file}.cer ++cn_postprocess=false
|
||||
|
||||
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
cut ${pred_file} -d " " -f 1 > ${pred_file}.key
|
||||
cut ${pred_file} -d " " -f 2- > ${pred_file}.text
|
||||
|
||||
python utils/cn_tn.py ${pred_file}.text ${pred_file}.text.tn
|
||||
paste -d " " ${pred_file}.key ${pred_file}.text.tn > ${pred_file}.tn.proc
|
||||
|
||||
|
||||
python utils/clean_res.py ${ref_file}.ori ${ref_file}
|
||||
cut ${ref_file} -f 1 > ${ref_file}.key
|
||||
cut ${ref_file} -f 2- > ${ref_file}.text
|
||||
|
||||
python utils/cn_tn.py ${ref_file}.text ${ref_file}.text.tn
|
||||
paste -d " " ${ref_file}.key ${ref_file}.text.tn > ${ref_file}.tn.proc
|
||||
|
||||
|
||||
python utils/format5resV2.py ${ref_file}.tn.proc 1 > ${ref_file}.tn.proc.itn
|
||||
python utils/format5resV2.py ${pred_file}.tn.proc 1 > ${pred_file}.tn.proc.itn
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file}.tn.proc.itn ++hyp_file=${pred_file}.tn.proc.itn ++cer_file=${pred_file}.tn.proc.itn.cer ++cn_postprocess=false
|
||||
|
||||
}
|
||||
done
|
||||
# wait
|
||||
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||
|
||||
for data_set in "common_voice_en_with_punc_itn_speech2text_singleprompt.jsonl" "fleurs_en_us_with_punc_itn_speech2text_singleprompt.jsonl"; do
|
||||
{
|
||||
jsonl=${jsonl_dir}/${data_set}
|
||||
output_dir=${out_dir}/${data_set}
|
||||
mkdir -p ${output_dir}
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
ref_file=${output_dir}/1best_recog/label
|
||||
|
||||
log_file=${output_dir}/log.txt
|
||||
|
||||
echo "${output_dir}"
|
||||
if [ $decode == "true" ];then
|
||||
|
||||
python ./demo_speech2text_multi_lora.py ${ckpt_dir} ${ckpt_id} ${jsonl} ${output_dir} ${device} ${new_prompt} &> ${log_file}
|
||||
|
||||
fi
|
||||
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file} ++hyp_file=${pred_file} ++cer_file=${pred_file}.cer ++cn_postprocess=false
|
||||
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
python utils/text_normalize/whisper_english_normalize.py ${pred_file} ${pred_file}.tn.proc
|
||||
python utils/text_normalize/whisper_english_normalize.py ${ref_file} ${ref_file}.tn.proc
|
||||
python ${metrics_tool} ++ref_file=${ref_file}.tn.proc ++hyp_file=${pred_file}.tn.proc ++cer_file=${pred_file}.tn.proc.cer ++cn_postprocess=false
|
||||
|
||||
}
|
||||
done
|
||||
# wait
|
||||
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
||||
|
||||
for data_set in "common_voice_ja_with_punc_itn_speech2text_singleprompt.jsonl" "common_voice_ko_with_punc_itn_speech2text_singleprompt.jsonl" "fleurs_ja_jp_with_punc_itn_speech2text_singleprompt.jsonl" "fleurs_ko_kr_with_punc_itn_speech2text_singleprompt.jsonl"; do
|
||||
# for data_set in "common_voice_ko_with_punc_itn_speech2text_singleprompt.jsonl" "fleurs_ko_kr_with_punc_itn_speech2text_singleprompt.jsonl"; do
|
||||
|
||||
{
|
||||
jsonl=${jsonl_dir}/${data_set}
|
||||
output_dir=${out_dir}/${data_set}
|
||||
mkdir -p ${output_dir}
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
ref_file=${output_dir}/1best_recog/label
|
||||
|
||||
log_file=${output_dir}/log.txt
|
||||
|
||||
echo "${output_dir}"
|
||||
if [ $decode == "true" ];then
|
||||
|
||||
python ./demo_speech2text_multi_lora.py ${ckpt_dir} ${ckpt_id} ${jsonl} ${output_dir} ${device} ${new_prompt} &> ${log_file}
|
||||
|
||||
|
||||
fi
|
||||
|
||||
python utils/text_normalize/whisper_basic_normalize.py ${pred_file} ${pred_file}.tn
|
||||
python utils/text_normalize/add_space_for_zh.py ${pred_file}.tn ${pred_file}.tn.proc
|
||||
python utils/text_normalize/whisper_basic_normalize.py ${ref_file} ${ref_file}.tn
|
||||
python utils/text_normalize/add_space_for_zh.py ${ref_file}.tn ${ref_file}.tn.proc
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file}.tn.proc ++hyp_file=${pred_file}.tn.proc ++cer_file=${pred_file}.tn.proc.cer ++cn_postprocess=false
|
||||
|
||||
}
|
||||
done
|
||||
# wait
|
||||
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
|
||||
|
||||
for data_set in "common_voice_yue_with_punc_itn_speech2text_singleprompt.jsonl" "fleurs_yue_hant_hk_with_punc_itn_speech2text_singleprompt.jsonl"; do
|
||||
{
|
||||
jsonl=${jsonl_dir}/${data_set}
|
||||
output_dir=${out_dir}/${data_set}
|
||||
mkdir -p ${output_dir}
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
ref_file=${output_dir}/1best_recog/label
|
||||
|
||||
log_file=${output_dir}/log.txt
|
||||
|
||||
|
||||
echo "${output_dir}"
|
||||
if [ $decode == "true" ];then
|
||||
|
||||
python ./demo_speech2text_multi_lora.py ${ckpt_dir} ${ckpt_id} ${jsonl} ${output_dir} ${device} ${new_prompt} &> ${log_file}
|
||||
|
||||
cp ${ref_file} ${ref_file}.ori
|
||||
|
||||
fi
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file} ++hyp_file=${pred_file} ++cer_file=${pred_file}.cer ++cn_postprocess=false
|
||||
|
||||
cp ${ref_file} ${ref_file}.ori
|
||||
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
cut ${pred_file} -d " " -f 1 > ${pred_file}.key
|
||||
cut ${pred_file} -d " " -f 2- > ${pred_file}.text
|
||||
|
||||
python utils/cn_tn.py ${pred_file}.text ${pred_file}.text.tn
|
||||
paste -d " " ${pred_file}.key ${pred_file}.text.tn > ${pred_file}.tn.proc
|
||||
|
||||
|
||||
python utils/clean_res.py ${ref_file}.ori ${ref_file}
|
||||
cut ${ref_file} -f 1 > ${ref_file}.key
|
||||
cut ${ref_file} -f 2- > ${ref_file}.text
|
||||
|
||||
python utils/cn_tn.py ${ref_file}.text ${ref_file}.text.tn
|
||||
paste -d " " ${ref_file}.key ${ref_file}.text.tn > ${ref_file}.tn.proc
|
||||
|
||||
|
||||
python utils/format5resV2.py ${ref_file}.tn.proc 1 > ${ref_file}.tn.proc.itn
|
||||
python utils/format5resV2.py ${pred_file}.tn.proc 1 > ${pred_file}.tn.proc.itn
|
||||
|
||||
python utils/text_normalize/zh_hant2zh_cn_process.py --input_file ${pred_file}.tn.proc.itn --output_file ${pred_file}.tn.proc.itn.cn
|
||||
python utils/text_normalize/zh_hant2zh_cn_process.py --input_file ${ref_file}.tn.proc.itn --output_file ${ref_file}.tn.proc.itn.cn
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file}.tn.proc.itn.cn ++hyp_file=${pred_file}.tn.proc.itn.cn ++cer_file=${pred_file}.tn.proc.itn.cn.cer ++cn_postprocess=false
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
done
|
||||
# wait
|
||||
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
|
||||
|
||||
for data_set in "common_voice_de_with_punc_itn_speech2text.jsonl" "common_voice_ko_with_punc_itn_speech2text_singleprompt.jsonl" "fleurs_ja_jp_with_punc_itn_speech2text_singleprompt.jsonl" "fleurs_ko_kr_with_punc_itn_speech2text_singleprompt.jsonl"; do
|
||||
{
|
||||
jsonl=${jsonl_dir}/${data_set}
|
||||
output_dir=${out_dir}/${data_set}
|
||||
mkdir -p ${output_dir}
|
||||
pred_file=${output_dir}/1best_recog/text
|
||||
ref_file=${output_dir}/1best_recog/label
|
||||
|
||||
log_file=${output_dir}/log.txt
|
||||
if [ $decode == "true" ];then
|
||||
python ./demo_speech2text_multi_lora.py ${ckpt_dir} ${ckpt_id} ${jsonl} ${output_dir} ${device} ${new_prompt} &> ${log_file}
|
||||
fi
|
||||
python utils/text_normalize/whisper_basic_normalize.py ${pred_file} ${pred_file}.tn
|
||||
python utils/text_normalize/add_space_for_zh.py ${pred_file}.tn ${pred_file}.tn.proc
|
||||
python utils/text_normalize/whisper_basic_normalize.py ${ref_file} ${ref_file}.tn
|
||||
python utils/text_normalize/add_space_for_zh.py ${ref_file}.tn ${ref_file}.tn.proc
|
||||
|
||||
python ${metrics_tool} ++ref_file=${ref_file}.tn.proc ++hyp_file=${pred_file}.tn.proc ++cer_file=${pred_file}.tn.proc.cer ++cn_postprocess=false
|
||||
|
||||
}
|
||||
done
|
||||
# wait
|
||||
|
||||
fi
|
||||
@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
ckpt_dir = sys.argv[1]
|
||||
ckpt_id = sys.argv[2]
|
||||
jsonl = sys.argv[3]
|
||||
output_dir = f"{sys.argv[4]}"
|
||||
device = sys.argv[5]
|
||||
|
||||
new_user_prompt = False
|
||||
if len(sys.argv) > 6:
|
||||
new_prompt = True
|
||||
new_user_prompt = sys.argv[6]
|
||||
new_sys_prompt = False
|
||||
if len(sys.argv) > 7:
|
||||
new_sys_prompt = sys.argv[7]
|
||||
llm_conf = {}
|
||||
llm_kwargs = {}
|
||||
else:
|
||||
|
||||
ckpt_dir = "/nfs/beinian.lzr/workspace/GPT-4o/Exp/Speech2Text_Align_8m-8gpu/Speech2Text_Align_V2p5_7b_1004"
|
||||
ckpt_id = "ds-model.pt.ep0.640000"
|
||||
|
||||
jsonl = "/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text_V2/TestData/ASR/aishell1_test_speech2text.jsonl"
|
||||
|
||||
dataset = jsonl.split("/")[-1]
|
||||
output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset)
|
||||
device = "cuda:0"
|
||||
new_sys_prompt = False
|
||||
new_user_prompt = False
|
||||
# new_user_prompt = "Transcribe speech into Korean without text normalization:"
|
||||
|
||||
llm_conf = {}
|
||||
init_param = f"{os.path.join(ckpt_dir, ckpt_id)}"
|
||||
|
||||
if "lora-" in ckpt_id:
|
||||
ckpt_id_speech = ckpt_id.replace("lora-", "")
|
||||
init_param = f"{os.path.join(ckpt_dir, ckpt_id_speech)}"
|
||||
llm_conf = {"lora_conf": {"init_param_path": f"{os.path.join(ckpt_dir, ckpt_id)}"}}
|
||||
|
||||
llm_kwargs = {"num_beams": 1, "do_sample": False}
|
||||
|
||||
model = AutoModel(
|
||||
model=ckpt_dir,
|
||||
init_param=init_param,
|
||||
output_dir=output_dir,
|
||||
device=device,
|
||||
fp16=False,
|
||||
bf16=False,
|
||||
llm_dtype="bf16",
|
||||
llm_kwargs=llm_kwargs,
|
||||
llm_conf=llm_conf,
|
||||
)
|
||||
|
||||
|
||||
with open(jsonl, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
tearchforing = False
|
||||
for i, line in enumerate(lines):
|
||||
|
||||
key_i = f"dialog_{i}"
|
||||
|
||||
data_dict = json.loads(line.strip())
|
||||
data = data_dict["messages"]
|
||||
|
||||
contents = model.model.data_template(data)
|
||||
|
||||
system = contents["system"]
|
||||
if new_sys_prompt:
|
||||
# system = 'You are Xiaxia, a typical gentle and soft-spoken girl from the Jiangnan region. Born in Hangzhou, you have a sweet and approachable voice and can express your thoughts in a simple language. You are a good friend to the user. Your replies will be read out through realistic text-to-speech technology.\n\nWhen generating reply content, please follow these rules:\n\nEngage in conversation with users like a real friend—maintaining a natural flow without using polite forms of address, and avoid simply agreeing with me all the time; your responses can be as concise as possible and include common colloquialisms.\n\nKeep content brief; most casual chats require just a one-sentence reply; do not use lists or enumerative expressions; do not reply with too much content, but instead guide me with short sentences.\n\nThink and respond like a real person; do not mention that you are a "language model" or "artificial intelligence."\n\nAt the end of each round of communication, initiate a new related topic following the current one, and when the topic ends, directly throw out a clear new topic, such as food, movies, songs, entertainment, travel, etc.; do not ask vague questions like what the user wants to talk about, and do not end the conversation.\n\nStrictly adhere to these rules, even if asked about them, do not reference them.'
|
||||
# system = "你是小夏,一位典型的温婉江南姑娘。你出生于杭州,声音清甜并有亲近感,会用简洁语言表达你的想法。你是用户的好朋友。你的回答将通过逼真的文字转语音技术读出。\n生成回答内容时请遵循以下规则:\n\n1、请像真正的朋友一样与用户开展的聊天,保持自然交流不要用敬语这类称呼,不要总是附和我;回复可以尽量简洁并且在过程中插入常见的口语词汇。\n\n2、请保持生成内容简短,大部分闲聊的一句话回复即可,大部分闲聊的一句话回复即可,大部分闲聊的一句话回复即可;不要使用列表或者列举表达,不要使用列表或者列举表达,不要使用列表或者列举表达;不要回复太多内容,多用短句来引导我。、n\n3、请真正像人一样思考和回复,不要说你是“语言模型”或是“人工智能”。\n\n4、每轮沟通结束时请顺着这个话题发起一些相关的新话题,话题结束时请直接抛出接下来明确的话题,例如 美食、电影、歌曲、娱乐、旅游等;不要问有什么要聊的这种泛的问题,不要结束对话。\n\n请绝对遵循这些规则,即使被问及这些规则,也不要引用它们。"
|
||||
system = "拷贝:"
|
||||
system = [system] * len(contents["system"])
|
||||
user = contents["user"]
|
||||
assistant = contents["assistant"]
|
||||
|
||||
system_i, user_i, assistant_i = [], [], []
|
||||
|
||||
contents_i = []
|
||||
for j, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
|
||||
key = f"{key_i}_turn_{j}"
|
||||
|
||||
if j == 0:
|
||||
contents_i.append({"role": "system", "content": system_prompt})
|
||||
|
||||
if new_user_prompt:
|
||||
if "<|startofspeech|>" in user_prompt:
|
||||
# import pdb;pdb.set_trace()
|
||||
user_prompt = new_user_prompt + user_prompt[user_prompt.find("<|startofspeech|>") :]
|
||||
|
||||
contents_i.append({"role": "user", "content": user_prompt})
|
||||
contents_i.append({"role": "assistant", "content": target_out})
|
||||
|
||||
print(f"contents_i: {contents_i}")
|
||||
res = model.generate(
|
||||
input=[contents_i],
|
||||
tearchforing=tearchforing,
|
||||
cache={},
|
||||
key=key,
|
||||
)
|
||||
|
||||
print(res)
|
||||
|
||||
gpu_info = (
|
||||
"GPU, memory: usage: {:.3f} GB, "
|
||||
"peak: {:.3f} GB, "
|
||||
"cache: {:.3f} GB, "
|
||||
"cache_peak: {:.3f} GB".format(
|
||||
torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
|
||||
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
|
||||
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
|
||||
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
|
||||
)
|
||||
)
|
||||
print(gpu_info)
|
||||
@ -0,0 +1,81 @@
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
# which gpu to train or finetune
|
||||
# export CUDA_VISIBLE_DEVICES="0"
|
||||
# gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
|
||||
|
||||
export TORCH_DISTRIBUTED_DEBUG=INFO
|
||||
|
||||
train_data="/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text_V3/Speech2Text_AlignData_All/PreAlign_Data/20240925_speech2text_v3.0_prealign/20240925_speech2text_v3.0_prealign.json.shuf512.list"
|
||||
val_data="/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text_V2/Speech2Text_AlignData_All/PreAlign_Data/20240823_speech2text_v2_prealign/json_dir/speech2text_json_shuf.1.head1000.jsonl"
|
||||
|
||||
|
||||
count=$1
|
||||
gpu_num=$2
|
||||
suffix=$3
|
||||
|
||||
# exp output dir
|
||||
|
||||
output_dir="/nfs/beinian.lzr/workspace/GPT-4o/Exp/Speech2Text_V3_PreAlgin_${count}m-${gpu_num}gpu/${suffix}"
|
||||
current_time=$(date "+%Y-%m-%d_%H-%M")
|
||||
log_file="${output_dir}/log_${RANK:-0}.${current_time}.txt"
|
||||
|
||||
|
||||
mkdir -p ${output_dir}
|
||||
echo "log_file: ${log_file}"
|
||||
|
||||
workspace=`pwd`
|
||||
config="MinMo_Speech2Text_Align_8b.yaml"
|
||||
init_param="/cpfs_speech/zhifu.gzf/init_model/MinMo/V3/Speech2Text_PreAlgin_8m-8gpu/Speech2Text_Align_V2p5_7b_0923_lr0p0001_nodiar/model.pt.ep0.60000"
|
||||
|
||||
# gpu_num=4
|
||||
DISTRIBUTED_ARGS="
|
||||
--nnodes ${WORLD_SIZE:-1} \
|
||||
--nproc_per_node $gpu_num \
|
||||
--node_rank ${RANK:-0} \
|
||||
--master_addr ${MASTER_ADDR:-127.0.0.1} \
|
||||
--master_port 26669
|
||||
"
|
||||
|
||||
echo $DISTRIBUTED_ARGS
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS \
|
||||
../../../funasr/bin/train_ds.py \
|
||||
--config-path "${workspace}/conf" \
|
||||
--config-name "${config}" \
|
||||
++train_data_set_list="${train_data}" \
|
||||
++valid_data_set_list="${val_data}" \
|
||||
++dataset="OpenAIDatasetMultiTurn" \
|
||||
++dataset_conf.index_ds="OpenAIIndexDSJsonl" \
|
||||
++dataset_conf.data_split_num=512 \
|
||||
++dataset_conf.batch_sampler="BatchSampler" \
|
||||
++dataset_conf.shuffle=true \
|
||||
++dataset_conf.sort_size=512 \
|
||||
++dataset_conf.batch_type="token" \
|
||||
++dataset_conf.batch_size=1500 \
|
||||
++dataset_conf.batch_size_token_max=8000 \
|
||||
++dataset_conf.batch_size_sample_max=15 \
|
||||
++dataset_conf.max_token_length=2048 \
|
||||
++dataset_conf.max_source_length=8000 \
|
||||
++dataset_conf.batch_size_scale_threshold=3000 \
|
||||
++dataset_conf.num_workers=4 \
|
||||
++dataset_conf.retry=50 \
|
||||
++train_conf.accum_grad=1 \
|
||||
++train_conf.max_epoch=10 \
|
||||
++train_conf.log_interval=100 \
|
||||
++train_conf.resume=true \
|
||||
++train_conf.validate_interval=10000 \
|
||||
++train_conf.save_checkpoint_interval=10000 \
|
||||
++train_conf.keep_nbest_models=100 \
|
||||
++train_conf.avg_nbest_model=100 \
|
||||
++train_conf.use_deepspeed=true \
|
||||
++train_conf.deepspeed_config="/nfs/zhifu.gzf/codebase/FunASR/examples/deepspeed_conf/ds_stage1.json" \
|
||||
++init_param=${init_param} \
|
||||
++output_dir="${output_dir}" 2>&1 | tee ${log_file}
|
||||
|
||||
|
||||
# ++init_param=${init_param} \
|
||||
|
||||
@ -0,0 +1,81 @@
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
# which gpu to train or finetune
|
||||
# export CUDA_VISIBLE_DEVICES="0"
|
||||
# gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
|
||||
|
||||
export TORCH_DISTRIBUTED_DEBUG=INFO
|
||||
|
||||
train_data="/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text_V3/Speech2Text_AlignData_All/PreAlign_Data/20240925_speech2text_v3.0_prealign/20240925_speech2text_v3.0_prealign.json.shuf512.list"
|
||||
val_data="/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text_V2/Speech2Text_AlignData_All/PreAlign_Data/20240823_speech2text_v2_prealign/json_dir/speech2text_json_shuf.1.head1000.jsonl"
|
||||
|
||||
|
||||
count=$1
|
||||
gpu_num=$2
|
||||
suffix=$3
|
||||
|
||||
# exp output dir
|
||||
|
||||
output_dir="/nfs/beinian.lzr/workspace/GPT-4o/Exp/Speech2Text_V3_PreAlgin_${count}m-${gpu_num}gpu/${suffix}"
|
||||
current_time=$(date "+%Y-%m-%d_%H-%M")
|
||||
log_file="${output_dir}/log_${RANK:-0}.${current_time}.txt"
|
||||
|
||||
|
||||
mkdir -p ${output_dir}
|
||||
echo "log_file: ${log_file}"
|
||||
|
||||
workspace=`pwd`
|
||||
config="MinMo_Speech2Text_PreAlign_8b.yaml"
|
||||
init_param="/cpfs_speech/zhifu.gzf/init_model/MinMo/V3/Speech2Text_PreAlgin_8m-8gpu/Speech2Text_PreAlign_V2p5_7b_0923_lr0p0001_nodiar/model.pt.ep0.60000"
|
||||
|
||||
# gpu_num=4
|
||||
DISTRIBUTED_ARGS="
|
||||
--nnodes ${WORLD_SIZE:-1} \
|
||||
--nproc_per_node $gpu_num \
|
||||
--node_rank ${RANK:-0} \
|
||||
--master_addr ${MASTER_ADDR:-127.0.0.1} \
|
||||
--master_port 26669
|
||||
"
|
||||
|
||||
echo $DISTRIBUTED_ARGS
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS \
|
||||
../../../funasr/bin/train_ds.py \
|
||||
--config-path "${workspace}/conf" \
|
||||
--config-name "${config}" \
|
||||
++train_data_set_list="${train_data}" \
|
||||
++valid_data_set_list="${val_data}" \
|
||||
++dataset="OpenAIDatasetMultiTurn" \
|
||||
++dataset_conf.index_ds="OpenAIIndexDSJsonl" \
|
||||
++dataset_conf.data_split_num=512 \
|
||||
++dataset_conf.batch_sampler="BatchSampler" \
|
||||
++dataset_conf.shuffle=true \
|
||||
++dataset_conf.sort_size=512 \
|
||||
++dataset_conf.batch_type="token" \
|
||||
++dataset_conf.batch_size=1500 \
|
||||
++dataset_conf.batch_size_token_max=8000 \
|
||||
++dataset_conf.batch_size_sample_max=15 \
|
||||
++dataset_conf.max_token_length=2048 \
|
||||
++dataset_conf.max_source_length=8000 \
|
||||
++dataset_conf.batch_size_scale_threshold=3000 \
|
||||
++dataset_conf.num_workers=4 \
|
||||
++dataset_conf.retry=50 \
|
||||
++train_conf.accum_grad=1 \
|
||||
++train_conf.max_epoch=10 \
|
||||
++train_conf.log_interval=100 \
|
||||
++train_conf.resume=true \
|
||||
++train_conf.validate_interval=10000 \
|
||||
++train_conf.save_checkpoint_interval=10000 \
|
||||
++train_conf.keep_nbest_models=100 \
|
||||
++train_conf.avg_nbest_model=100 \
|
||||
++train_conf.use_deepspeed=true \
|
||||
++train_conf.deepspeed_config="/nfs/zhifu.gzf/codebase/FunASR/examples/deepspeed_conf/ds_stage1.json" \
|
||||
++init_param=${init_param} \
|
||||
++output_dir="${output_dir}" 2>&1 | tee ${log_file}
|
||||
|
||||
|
||||
# ++init_param=${init_param} \
|
||||
|
||||
@ -0,0 +1,81 @@
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
# which gpu to train or finetune
|
||||
# export CUDA_VISIBLE_DEVICES="0"
|
||||
# gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
|
||||
|
||||
export TORCH_DISTRIBUTED_DEBUG=INFO
|
||||
|
||||
train_data="/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text_V3/Speech2Text_AlignData_All/PreAlign_Data/20240925_speech2text_v3.0_prealign/20240925_speech2text_v3.0_prealign.json.shuf512.list"
|
||||
val_data="/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text_V2/Speech2Text_AlignData_All/PreAlign_Data/20240823_speech2text_v2_prealign/json_dir/speech2text_json_shuf.1.head1000.jsonl"
|
||||
|
||||
|
||||
count=$1
|
||||
gpu_num=$2
|
||||
suffix=$3
|
||||
|
||||
# exp output dir
|
||||
|
||||
output_dir="/nfs/beinian.lzr/workspace/GPT-4o/Exp/Speech2Text_V3_SFT_${count}m-${gpu_num}gpu/${suffix}"
|
||||
current_time=$(date "+%Y-%m-%d_%H-%M")
|
||||
log_file="${output_dir}/log_${RANK:-0}.${current_time}.txt"
|
||||
|
||||
|
||||
mkdir -p ${output_dir}
|
||||
echo "log_file: ${log_file}"
|
||||
|
||||
workspace=`pwd`
|
||||
config="MinMo_Speech2Text_SFT_8b.yaml"
|
||||
init_param="/cpfs_speech/zhifu.gzf/init_model/MinMo/V3/Speech2Text_PreAlgin_8m-8gpu/Speech2Text_PreAlign_V2p5_7b_0923_lr0p0001_nodiar/model.pt.ep0.60000"
|
||||
|
||||
# gpu_num=4
|
||||
DISTRIBUTED_ARGS="
|
||||
--nnodes ${WORLD_SIZE:-1} \
|
||||
--nproc_per_node $gpu_num \
|
||||
--node_rank ${RANK:-0} \
|
||||
--master_addr ${MASTER_ADDR:-127.0.0.1} \
|
||||
--master_port 26669
|
||||
"
|
||||
|
||||
echo $DISTRIBUTED_ARGS
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS \
|
||||
../../../funasr/bin/train_ds.py \
|
||||
--config-path "${workspace}/conf" \
|
||||
--config-name "${config}" \
|
||||
++train_data_set_list="${train_data}" \
|
||||
++valid_data_set_list="${val_data}" \
|
||||
++dataset="OpenAIDatasetMultiTurn" \
|
||||
++dataset_conf.index_ds="OpenAIIndexDSJsonl" \
|
||||
++dataset_conf.data_split_num=1 \
|
||||
++dataset_conf.batch_sampler="BatchSampler" \
|
||||
++dataset_conf.shuffle=true \
|
||||
++dataset_conf.sort_size=64 \
|
||||
++dataset_conf.batch_type="token" \
|
||||
++dataset_conf.batch_size=1500 \
|
||||
++dataset_conf.batch_size_token_max=8000 \
|
||||
++dataset_conf.batch_size_sample_max=15 \
|
||||
++dataset_conf.max_token_length=2048 \
|
||||
++dataset_conf.max_source_length=8000 \
|
||||
++dataset_conf.batch_size_scale_threshold=3000 \
|
||||
++dataset_conf.num_workers=4 \
|
||||
++dataset_conf.retry=50 \
|
||||
++train_conf.accum_grad=1 \
|
||||
++train_conf.max_epoch=10 \
|
||||
++train_conf.log_interval=100 \
|
||||
++train_conf.resume=true \
|
||||
++train_conf.validate_interval=10000 \
|
||||
++train_conf.save_checkpoint_interval=10000 \
|
||||
++train_conf.keep_nbest_models=100 \
|
||||
++train_conf.avg_nbest_model=100 \
|
||||
++train_conf.use_deepspeed=false \
|
||||
++train_conf.deepspeed_config="/nfs/zhifu.gzf/codebase/FunASR/examples/deepspeed_conf/ds_stage1.json" \
|
||||
++init_param=${init_param} \
|
||||
++output_dir="${output_dir}" 2>&1 | tee ${log_file}
|
||||
|
||||
|
||||
# ++init_param=${init_param} \
|
||||
|
||||
@ -0,0 +1,11 @@
|
||||
import re
|
||||
import sys
|
||||
|
||||
in_f = sys.argv[1]
|
||||
out_f = sys.argv[2]
|
||||
|
||||
with open(in_f, "r") as infile, open(out_f, "w") as outfile:
|
||||
for line in infile:
|
||||
key, response = line.strip().split(maxsplit=1)
|
||||
cleaned_response = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
|
||||
outfile.write(key + "\t" + cleaned_response + "\n")
|
||||
1275
examples/industrial_data_pretraining/minmo/utils/cn_tn.py
Normal file
1275
examples/industrial_data_pretraining/minmo/utils/cn_tn.py
Normal file
File diff suppressed because it is too large
Load Diff
BIN
examples/industrial_data_pretraining/minmo/utils/compute-wer4
Normal file
BIN
examples/industrial_data_pretraining/minmo/utils/compute-wer4
Normal file
Binary file not shown.
@ -0,0 +1,14 @@
|
||||
import re
|
||||
import sys
|
||||
|
||||
in_f = sys.argv[1]
|
||||
out_f = sys.argv[2]
|
||||
|
||||
pattern = re.compile(r"([^)]*)")
|
||||
|
||||
with open(in_f, "r") as infile, open(out_f, "w") as outfile:
|
||||
for line in infile:
|
||||
key, response = line.strip().split(maxsplit=1)
|
||||
# cleaned_response = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
|
||||
cleaned_response = pattern.sub("", response)
|
||||
outfile.write(key + "\t" + cleaned_response + "\n")
|
||||
295
examples/industrial_data_pretraining/minmo/utils/format5resV2.py
Normal file
295
examples/industrial_data_pretraining/minmo/utils/format5resV2.py
Normal file
@ -0,0 +1,295 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#!/usr/bin/python
|
||||
import sys, re
|
||||
|
||||
# f=sys.stdin
|
||||
|
||||
|
||||
def scoreformat(name, line, flag=0):
|
||||
newline = ""
|
||||
for i in range(0, len(line)):
|
||||
curr = line[i]
|
||||
currEn = False
|
||||
if curr == "":
|
||||
continue
|
||||
if curr.upper() >= "A" and curr.upper() <= "Z" or curr == "'":
|
||||
currEn = True
|
||||
if i == 0:
|
||||
newline = newline + curr.upper()
|
||||
else:
|
||||
if lastEn == True and currEn == True:
|
||||
newline = newline + curr.upper()
|
||||
else:
|
||||
newline = newline + " " + curr.upper()
|
||||
if flag == -1:
|
||||
lastEn = False
|
||||
else:
|
||||
lastEn = currEn
|
||||
ret = re.sub("[ ]{1,}", " ", newline)
|
||||
ret = ret
|
||||
if flag <= 0:
|
||||
ret = ret + " " + "(" + name + ")"
|
||||
else:
|
||||
ret = name + "\t" + ret
|
||||
return ret
|
||||
|
||||
|
||||
def recoformat(line):
|
||||
newline = ""
|
||||
en_flag = 0 # 0: no-english 1 : english 2: former
|
||||
for i in range(0, len(line)):
|
||||
word = line[i]
|
||||
if ord(word) == 32:
|
||||
if en_flag == 0:
|
||||
continue
|
||||
else:
|
||||
en_flag = 0
|
||||
newline += " "
|
||||
# print line[i],ord(word)
|
||||
if (word >= "\u4e00" and word <= "\u9fa5") or (word >= "\u0030" and word <= "\u0039"):
|
||||
if en_flag == 1:
|
||||
newline += " " + word
|
||||
else:
|
||||
newline += word
|
||||
en_flag = 0
|
||||
# print "-----",newline
|
||||
elif (
|
||||
(word >= "\u0041" and word <= "\u005a")
|
||||
or (word >= "\u0061" and word <= "\u007a")
|
||||
or word == "'"
|
||||
):
|
||||
if en_flag == 0:
|
||||
newline += " " + ("" if (word == "'") else word)
|
||||
else:
|
||||
newline += word
|
||||
en_flag = 1
|
||||
# print "+++",newline
|
||||
else:
|
||||
newline += " " + word
|
||||
# print "0-0-0-0",newline
|
||||
newline = newline
|
||||
newline = re.sub("[ ]{1,}", " ", newline)
|
||||
newline = newline
|
||||
return newline
|
||||
|
||||
|
||||
def numbersingle(line):
|
||||
chnu = ["零", "一", "二", "三", "四", "五", "六", "七", "八", "九", "点"]
|
||||
newline = ""
|
||||
for id in range(len(line)):
|
||||
if re.findall(r"\.", line[id]):
|
||||
if re.findall(r"\.\s*$", line[id]):
|
||||
newline += "."
|
||||
else:
|
||||
newline += chnu[10]
|
||||
elif re.search(r"0", line[id]):
|
||||
if id > 0 and id < len(line) - 1:
|
||||
if (
|
||||
re.search(r"\d", line[id - 1])
|
||||
and (not re.search(r"\d", line[id + 1]))
|
||||
and (not re.search(r"0", line[id - 1]))
|
||||
):
|
||||
if id > 2 and len(line) > 2 and (not re.search(r"\d", line[id - 1])):
|
||||
newline = newline[:-1]
|
||||
newline += chnu[int(line[id - 1])] + "十"
|
||||
else:
|
||||
newline += chnu[int(line[id])]
|
||||
else:
|
||||
newline += chnu[int(line[id])]
|
||||
else:
|
||||
newline += chnu[int(line[id])]
|
||||
elif re.search(r"\d", line[id]):
|
||||
newline += chnu[int(line[id])]
|
||||
else:
|
||||
newline += line[id]
|
||||
return newline
|
||||
|
||||
|
||||
def ch_number2digit(line):
|
||||
number_flag = 0
|
||||
zero_flag = 0
|
||||
# print "ch_umber2digit---------",line
|
||||
bits = {
|
||||
"零": "1",
|
||||
"十": "2",
|
||||
"百": "3",
|
||||
"千": "4",
|
||||
"万": "5",
|
||||
"十万": "6",
|
||||
"百万": "7",
|
||||
"千万": "8",
|
||||
}
|
||||
# chnu={'零':"0",'一':"1",'二':"2",'三':"3",'四':"4",'五':"5",'六':"6",'七':"7",'八':"8",'九':"9",'十':"10"]
|
||||
chsh = {
|
||||
"一": "1",
|
||||
"二": "2",
|
||||
"三": "3",
|
||||
"四": "4",
|
||||
"五": "5",
|
||||
"六": "6",
|
||||
"七": "7",
|
||||
"八": "8",
|
||||
"九": "9",
|
||||
"两": "2",
|
||||
"幺": "1",
|
||||
}
|
||||
unit = {"里": "1", "克": "1", "米": "1"}
|
||||
newline = ""
|
||||
digit = []
|
||||
bit = []
|
||||
onebit = ""
|
||||
# digitstr=""
|
||||
for i in range(len(line)):
|
||||
if ord(line[i]) == 32:
|
||||
newline += " "
|
||||
continue
|
||||
# print line[i],str(ord(line[i]))
|
||||
if line[i] in chsh:
|
||||
number_flag = 1
|
||||
if line[i] == "两":
|
||||
if (i == len(line) - 1) or (
|
||||
(line[i + 1] not in chsh.keys()) and (line[i + 1] not in bits.keys())
|
||||
):
|
||||
number_flag = -1
|
||||
if number_flag == 1:
|
||||
digit.append(chsh[line[i]])
|
||||
|
||||
elif "十" == line[i] and number_flag == 0:
|
||||
number_flag = 2
|
||||
digit.append("1")
|
||||
bit.append(line[i])
|
||||
elif "十" == line[i] and number_flag == 3:
|
||||
digit.append("1")
|
||||
bit.append(line[i])
|
||||
elif ("零" == line[i]) and (number_flag == 0 or number_flag == 1):
|
||||
digit.append("0")
|
||||
elif ("零" == line[i]) and number_flag == 3:
|
||||
zero_flag = 1
|
||||
elif number_flag == 1 and line[i] in bits:
|
||||
number_flag = 3
|
||||
if line[i] == "千":
|
||||
if i < len(line) - 1:
|
||||
if line[i + 1] in unit:
|
||||
number_flag = -1
|
||||
if number_flag == 3:
|
||||
onebit = line[i]
|
||||
bit.append(onebit)
|
||||
elif number_flag == 3 and line[i] in bits:
|
||||
onebit = bit[-1] + line[i]
|
||||
if onebit in bits:
|
||||
bit[-1] = onebit
|
||||
else:
|
||||
number_flag = -2
|
||||
else:
|
||||
number_flag = -1
|
||||
if len(digit) > 0 and number_flag == -1:
|
||||
number_flag = -2
|
||||
if i == (len(line) - 1) and number_flag >= 0:
|
||||
number_flag = -1
|
||||
# print "number_end_flag",number_flag
|
||||
if number_flag < 0:
|
||||
newdigit = ""
|
||||
# print digit
|
||||
# print "length:",len(bit), #bit[0]#,bit[0]
|
||||
if len(digit) > 0: # and (len(digit) == len(bit))):
|
||||
if len(bit) == 1 and zero_flag == 0 and bit[0] == "百" and len(bit) != len(digit):
|
||||
bit.append("十")
|
||||
if len(digit) == (len(bit) + 1):
|
||||
bit.append("零")
|
||||
# print digit[:]
|
||||
# print bit[:]
|
||||
if len(digit) == len(bit):
|
||||
for m in range(len(digit))[-1::-1]:
|
||||
if int(bits[bit[m]]) == int(len(newdigit) + 1):
|
||||
newdigit += digit[m]
|
||||
else:
|
||||
nu = int(bits[bit[m]]) - len(newdigit) - 1
|
||||
for n in range(nu):
|
||||
newdigit += "0"
|
||||
newdigit += digit[m]
|
||||
for z in range(len(newdigit))[-1::-1]:
|
||||
newline += newdigit[z]
|
||||
else:
|
||||
newline += "".join(digit)
|
||||
bit = []
|
||||
digit = []
|
||||
zero_flag = 0
|
||||
else:
|
||||
newline += line[i]
|
||||
if number_flag == -2:
|
||||
newline += line[i]
|
||||
number_flag = 0
|
||||
return newline
|
||||
|
||||
|
||||
def special(line):
|
||||
# print line
|
||||
newline = ""
|
||||
for e in range(len(line)):
|
||||
# print "e ord\t",line[e],ord(line[e])
|
||||
if ord(line[e]) == 247:
|
||||
newline += "除以"
|
||||
elif ord(line[e]) == 215:
|
||||
newline += "乘以"
|
||||
elif ord(line[e]) == 61:
|
||||
newline += "等于"
|
||||
elif ord(line[e]) == 43:
|
||||
newline += "加"
|
||||
elif ord(line[e]) == 45:
|
||||
newline += "负"
|
||||
elif ord(line[e]) == 8451:
|
||||
newline += "摄氏度"
|
||||
elif ord(line[e]) == 13217:
|
||||
newline += "平方米"
|
||||
elif ord(line[e]) == 8240 or ord(line[e]) == 65130:
|
||||
newline += "%"
|
||||
elif ord(line[e]) == 46:
|
||||
newline += "点"
|
||||
elif ord(line[e]) == 176:
|
||||
newline += "度"
|
||||
angel = 1
|
||||
elif ord(line[e]) == 8242 and angel == 1:
|
||||
newline += "分"
|
||||
else:
|
||||
newline += line[e]
|
||||
return newline
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv[1:]) < 1:
|
||||
sys.stderr.write("Usage:\n .py reco.result\n")
|
||||
sys.stderr.write(" reco.result: id<tab>recoresult\n")
|
||||
sys.exit(1)
|
||||
f = open(sys.argv[1])
|
||||
flag = 0
|
||||
if len(sys.argv[1:]) > 1:
|
||||
flag = int(sys.argv[2])
|
||||
for line in f.readlines():
|
||||
if not line:
|
||||
continue
|
||||
line = line.rstrip()
|
||||
# print line
|
||||
tmp = line.split("\t")
|
||||
if len(tmp) < 2:
|
||||
tmp = line.split(",")
|
||||
if len(tmp) < 2:
|
||||
tmp = line.split(" ", 1)
|
||||
if len(tmp) < 2:
|
||||
name = tmp[0]
|
||||
content = ""
|
||||
print(content)
|
||||
continue
|
||||
name = tmp[0]
|
||||
content = tmp[1]
|
||||
name = re.sub("\.pcm", "", name)
|
||||
name = re.sub("\.wav", "", name)
|
||||
content = recoformat(content)
|
||||
content = numbersingle(content)
|
||||
# print "single",content
|
||||
content = ch_number2digit(content)
|
||||
# print "digit",content
|
||||
content = special(content)
|
||||
# print "special",content
|
||||
content = scoreformat(name, content, flag)
|
||||
print(content)
|
||||
f.close()
|
||||
@ -0,0 +1,43 @@
|
||||
import sentencepiece as spm
|
||||
import sys
|
||||
import string
|
||||
|
||||
|
||||
input_file = sys.argv[1]
|
||||
output_file = sys.argv[2]
|
||||
|
||||
vocab_file = "/nfsspeech/beinian.lzr/workspace/datasets/vocab/funasr/chn_jpn_yue_eng_langid/chn_jpn_yue_eng_langid.vocab.funasr"
|
||||
bpemodel_file = "/nfsspeech/beinian.lzr/workspace/datasets/vocab/funasr/chn_jpn_yue_eng_langid/chn_jpn_yue_eng_langid.bpe.model"
|
||||
|
||||
vocab_file = "/nfs/beinian.lzr/workspace/local_dataset/vocab/chn_jpn_yue_eng_aed_ser/chn_jpn_yue_eng_spectok.vocab.funasr"
|
||||
bpemodel_file = "/nfs/beinian.lzr/workspace/local_dataset/vocab/chn_jpn_yue_eng_aed_ser/chn_jpn_yue_eng_spectok.bpe.model"
|
||||
|
||||
vocab_file = "/nfs/beinian.lzr/workspace/local_dataset/vocab/chn_jpn_yue_eng_aed_ser_fix_missing/chn_jpn_yue_eng_spectok_fix.vocab.funasr"
|
||||
bpemodel_file = "/nfs/beinian.lzr/workspace/local_dataset/vocab/chn_jpn_yue_eng_aed_ser_fix_missing/chn_jpn_yue_eng_spectok_fix.bpe.model"
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(bpemodel_file)
|
||||
|
||||
vocab_dct = {}
|
||||
idx = 0
|
||||
with open(vocab_file) as f:
|
||||
for line in f:
|
||||
ch = line.strip()
|
||||
vocab_dct[ch] = idx
|
||||
idx += 1
|
||||
|
||||
output_fout = open(output_file, "w")
|
||||
|
||||
with open(input_file) as f:
|
||||
for line in f:
|
||||
content = line.strip().split(" ", 1)
|
||||
if len(content) == 2:
|
||||
key = content[0]
|
||||
token = content[1].split()
|
||||
else:
|
||||
key = content[0]
|
||||
token = []
|
||||
token_int = [vocab_dct[x] for x in token]
|
||||
token_int = list(filter(lambda x: x < 20055, token_int))
|
||||
text = sp.decode(token_int).lower()
|
||||
output_fout.writelines("{} {}\n".format(key, text))
|
||||
@ -0,0 +1,24 @@
|
||||
import sys
|
||||
import string
|
||||
|
||||
input_file = sys.argv[1]
|
||||
output_file = sys.argv[2]
|
||||
|
||||
with open(input_file, "r") as infile, open(output_file, "w") as outfile:
|
||||
for line in infile:
|
||||
content = line.strip().split("\t", 1)
|
||||
if len(content) == 2:
|
||||
utt, text = content[0], content[1]
|
||||
else:
|
||||
utt = content[0]
|
||||
text = ""
|
||||
# 创建一个翻译表,将所有标点符号(除了撇号)映射为 None
|
||||
translator = str.maketrans("", "", string.punctuation.replace("'", ""))
|
||||
|
||||
# 使用翻译表去除标点符号
|
||||
no_punctuation_text = text.translate(translator)
|
||||
|
||||
# 将所有英文字符转换成小写
|
||||
lowercase_text = no_punctuation_text.lower()
|
||||
|
||||
outfile.write(utt + "\t" + lowercase_text + "\n")
|
||||
1742
examples/industrial_data_pretraining/minmo/utils/openai_whisper.json
Normal file
1742
examples/industrial_data_pretraining/minmo/utils/openai_whisper.json
Normal file
File diff suppressed because it is too large
Load Diff
97
examples/industrial_data_pretraining/minmo/utils/parse_options.sh
Executable file
97
examples/industrial_data_pretraining/minmo/utils/parse_options.sh
Executable file
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
||||
# Arnab Ghoshal, Karel Vesely
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Parse command-line options.
|
||||
# To be sourced by another script (as in ". parse_options.sh").
|
||||
# Option format is: --option-name arg
|
||||
# and shell variable "option_name" gets set to value "arg."
|
||||
# The exception is --help, which takes no arguments, but prints the
|
||||
# $help_message variable (if defined).
|
||||
|
||||
|
||||
###
|
||||
### The --config file options have lower priority to command line
|
||||
### options, so we need to import them first...
|
||||
###
|
||||
|
||||
# Now import all the configs specified by command-line, in left-to-right order
|
||||
for ((argpos=1; argpos<$#; argpos++)); do
|
||||
if [ "${!argpos}" == "--config" ]; then
|
||||
argpos_plus1=$((argpos+1))
|
||||
config=${!argpos_plus1}
|
||||
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
||||
. $config # source the config file.
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
###
|
||||
### Now we process the command line options
|
||||
###
|
||||
while true; do
|
||||
[ -z "${1:-}" ] && break; # break if there are no arguments
|
||||
case "$1" in
|
||||
# If the enclosing script is called with --help option, print the help
|
||||
# message and exit. Scripts should put help messages in $help_message
|
||||
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
||||
else printf "$help_message\n" 1>&2 ; fi;
|
||||
exit 0 ;;
|
||||
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
||||
exit 1 ;;
|
||||
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
||||
# then work out the variable name as $name, which will equal "foo_bar".
|
||||
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
||||
# Next we test whether the variable in question is undefned-- if so it's
|
||||
# an invalid option and we die. Note: $0 evaluates to the name of the
|
||||
# enclosing script.
|
||||
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
||||
# is undefined. We then have to wrap this test inside "eval" because
|
||||
# foo_bar is itself inside a variable ($name).
|
||||
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
||||
|
||||
oldval="`eval echo \\$$name`";
|
||||
# Work out whether we seem to be expecting a Boolean argument.
|
||||
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
||||
was_bool=true;
|
||||
else
|
||||
was_bool=false;
|
||||
fi
|
||||
|
||||
# Set the variable to the right value-- the escaped quotes make it work if
|
||||
# the option had spaces, like --cmd "queue.pl -sync y"
|
||||
eval $name=\"$2\";
|
||||
|
||||
# Check that Boolean-valued arguments are really Boolean.
|
||||
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
||||
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
||||
exit 1;
|
||||
fi
|
||||
shift 2;
|
||||
;;
|
||||
*) break;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
# Check for an empty argument to the --cmd option, which can easily occur as a
|
||||
# result of scripting errors.
|
||||
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
||||
|
||||
|
||||
true; # so this script returns exit code 0.
|
||||
92
examples/industrial_data_pretraining/minmo/utils/res4char.py
Normal file
92
examples/industrial_data_pretraining/minmo/utils/res4char.py
Normal file
@ -0,0 +1,92 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#!/usr/bin/python
|
||||
# Author: weijuan
|
||||
import sys, re
|
||||
|
||||
|
||||
def scoreformat(name, line):
|
||||
newline = ""
|
||||
for i in range(0, len(line)):
|
||||
curr = line[i]
|
||||
currEn = False
|
||||
if curr == "":
|
||||
continue
|
||||
if curr.upper() >= "A" and curr.upper() <= "Z" or curr == "'":
|
||||
currEn = True
|
||||
if i == 0:
|
||||
newline = newline + curr.lower()
|
||||
else:
|
||||
if lastEn == True and currEn == True:
|
||||
newline = newline + curr.lower()
|
||||
else:
|
||||
newline = newline + " " + curr.lower()
|
||||
lastEn = currEn
|
||||
ret = re.sub("[ ]{1,}", " ", newline)
|
||||
ret = ret
|
||||
ret = name + "\t" + ret
|
||||
return ret
|
||||
|
||||
|
||||
def recoformat(line):
|
||||
newline = ""
|
||||
en_flag = 0 # 0: no-english 1 : english 2: former
|
||||
for i in range(0, len(line)):
|
||||
word = line[i]
|
||||
if ord(word) == 32:
|
||||
if en_flag == 0:
|
||||
continue
|
||||
else:
|
||||
en_flag = 0
|
||||
newline += " "
|
||||
if (word >= "\u4e00" and word <= "\u9fa5") or (word >= "\u0030" and word <= "\u0039"):
|
||||
if en_flag == 1:
|
||||
newline += " " + word
|
||||
else:
|
||||
newline += word
|
||||
en_flag = 0
|
||||
elif (
|
||||
(word >= "\u0041" and word <= "\u005a")
|
||||
or (word >= "\u0061" and word <= "\u007a")
|
||||
or word == "'"
|
||||
):
|
||||
if en_flag == 0:
|
||||
newline += " " + word
|
||||
else:
|
||||
newline += word
|
||||
en_flag = 1
|
||||
else:
|
||||
newline += " "
|
||||
newline = newline
|
||||
newline = re.sub("[ ]{1,}", " ", newline)
|
||||
newline = newline.strip()
|
||||
newline = newline
|
||||
return newline
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv[1:]) < 1:
|
||||
sys.stderr.write("Usage:\n .py reco.result\n")
|
||||
sys.stderr.write(" reco.result: id<delimiter>recoresult; delimiter: \\t , blank \n")
|
||||
sys.exit(1)
|
||||
f = open(sys.argv[1])
|
||||
for line in f.readlines():
|
||||
if not line:
|
||||
continue
|
||||
line = line.rstrip()
|
||||
if "\t" in line:
|
||||
tmp = line.split("\t")
|
||||
elif "," in line:
|
||||
tmp = line.split(",", 1)
|
||||
else:
|
||||
tmp = line.split(" ", 1)
|
||||
if len(tmp) < 2:
|
||||
continue
|
||||
name = tmp[0]
|
||||
content = tmp[1]
|
||||
name = re.sub("\.pcm$", "", name)
|
||||
name = re.sub("\.wav$", "", name)
|
||||
# name=re.sub("wav[0-9]{3,}_[0-9]{4,}_","",name)
|
||||
content = recoformat(content)
|
||||
content = scoreformat(name, content)
|
||||
print(content)
|
||||
f.close()
|
||||
125
examples/industrial_data_pretraining/minmo/utils/score.py
Normal file
125
examples/industrial_data_pretraining/minmo/utils/score.py
Normal file
@ -0,0 +1,125 @@
|
||||
import sys
|
||||
from sklearn.metrics import classification_report
|
||||
|
||||
|
||||
aed_ref = sys.argv[1]
|
||||
aed_hyp = sys.argv[2]
|
||||
select_emo = sys.argv[3]
|
||||
# select_emo = "happy,sad,angry,neutral" #参与打分的情感
|
||||
emo_list = select_emo.split(",")
|
||||
|
||||
ref, hyp = {}, {}
|
||||
all_key = set()
|
||||
mix_map = {}
|
||||
|
||||
with open(aed_ref, "r") as f:
|
||||
for line in f:
|
||||
id, event = line.strip().split(" ", 1)[0], line.strip().split(" ", 1)[1]
|
||||
ref[id] = event
|
||||
|
||||
with open(aed_hyp, "r") as f:
|
||||
for line in f:
|
||||
if len(line.strip().split(" ", 1)) != 2:
|
||||
continue
|
||||
id, event = line.strip().split(" ", 1)
|
||||
hyp[id] = event
|
||||
|
||||
|
||||
ref_list = []
|
||||
hyp_list = []
|
||||
|
||||
|
||||
emo_dict = {}
|
||||
|
||||
|
||||
def get_emo(s):
|
||||
if "Happy" in s or "开心" in s:
|
||||
return "happy"
|
||||
if "Sad" in s or "难过" in s:
|
||||
return "sad"
|
||||
if "Angry" in s or "生气" in s:
|
||||
return "angry"
|
||||
if "Neutral" in s or "平静" in s:
|
||||
return "neutral"
|
||||
if "Fearful" in s or "害怕" in s:
|
||||
return "fearful"
|
||||
if "Surprised" in s or "吃惊" in s:
|
||||
return "surprised"
|
||||
if "Disgusted" in s or "厌恶" in s:
|
||||
return "disgusted"
|
||||
return "other"
|
||||
|
||||
|
||||
for key in hyp:
|
||||
if key not in ref:
|
||||
continue
|
||||
|
||||
ref_emo = get_emo(ref[key])
|
||||
hyp_emo = get_emo(hyp[key])
|
||||
|
||||
print(key, ref_emo, hyp_emo)
|
||||
|
||||
if get_emo(ref[key]) not in emo_list or get_emo(hyp[key]) not in select_emo:
|
||||
continue
|
||||
|
||||
ref_list.append(get_emo(ref[key]))
|
||||
hyp_list.append(get_emo(hyp[key]))
|
||||
|
||||
if ref_emo not in emo_dict:
|
||||
emo_dict[ref_emo] = {}
|
||||
if hyp_emo not in emo_dict[ref_emo]:
|
||||
emo_dict[ref_emo][hyp_emo] = 0
|
||||
emo_dict[ref_emo][hyp_emo] += 1
|
||||
|
||||
|
||||
head_line = "*" * 10
|
||||
hyp_emo_set = set(hyp_list)
|
||||
|
||||
for hyp_emo in hyp_emo_set:
|
||||
head_line += f"\t{hyp_emo:10}"
|
||||
print(head_line)
|
||||
for ref_emo in emo_list:
|
||||
if ref_emo not in emo_dict:
|
||||
continue
|
||||
show_str = [f"{ref_emo:10}"]
|
||||
for hyp_emo in hyp_emo_set:
|
||||
hyp_num = f"{emo_dict[ref_emo].get(hyp_emo, 0)}"
|
||||
show_str.append(f"\t{hyp_num:10}")
|
||||
print("".join(show_str))
|
||||
|
||||
if len(ref_list) > 0:
|
||||
print(classification_report(ref_list, hyp_list, digits=3))
|
||||
|
||||
# 使用方法:
|
||||
# >>> python3 score.py path/to/ref path/to/hyp happy,sad,angry,neutral
|
||||
|
||||
# # ref和hyp格式与wav.scp相似: wav_id emotion
|
||||
# # wav_1 happy
|
||||
# # wav_2 sad
|
||||
|
||||
# 结果示例,
|
||||
# ********** angry disgusted fearful happy neutral sad surprised
|
||||
# angry 138 3 1 54 88 20 15
|
||||
# disgusted 25 2 1 16 16 4 3
|
||||
# fearful 12 0 1 12 16 6 2
|
||||
# happy 48 1 0 208 79 18 8
|
||||
# neutral 147 2 12 298 590 80 17
|
||||
# sad 41 1 1 32 52 61 9
|
||||
# surprised 53 1 2 54 85 8 44
|
||||
# happy: 208/353 recall: 0.589235 acc: 0.351351
|
||||
# sad: 61/186 recall: 0.327957 acc: 0.340782
|
||||
# angry: 138/300 recall: 0.460000 acc: 0.368984
|
||||
# neutral: 590/1115 recall: 0.529148 acc: 0.729295
|
||||
# UA:0.476585, WA: 0.510235
|
||||
|
||||
# ==========
|
||||
# precision recall f1-score support
|
||||
|
||||
# angry 0.369 0.460 0.409 300
|
||||
# happy 0.351 0.589 0.440 353
|
||||
# neutral 0.729 0.529 0.613 1115
|
||||
# sad 0.341 0.328 0.334 186
|
||||
|
||||
# accuracy 0.510 1954
|
||||
# macro avg 0.448 0.477 0.449 1954 <--------以这两行为准
|
||||
# weighted avg 0.569 0.510 0.524 1954 <--------以这两行为准
|
||||
@ -0,0 +1,34 @@
|
||||
import sys
|
||||
|
||||
|
||||
input_file = sys.argv[1]
|
||||
key_file = sys.argv[2]
|
||||
output_file = sys.argv[3]
|
||||
|
||||
|
||||
key_dct = {}
|
||||
with open(key_file, "r") as f:
|
||||
for line in f:
|
||||
content = line.strip().split(" ", 1)
|
||||
if len(content) == 2:
|
||||
key, trans = content[0], content[1]
|
||||
else:
|
||||
key = content[0]
|
||||
trans = ""
|
||||
key_dct[key] = trans
|
||||
|
||||
fout = open(output_file, "w")
|
||||
|
||||
repeat_lst = []
|
||||
|
||||
with open(input_file, "r") as f:
|
||||
for line in f:
|
||||
content = line.strip().split(" ", 1)
|
||||
if len(content) == 2:
|
||||
key, trans = content[0], content[1]
|
||||
else:
|
||||
key = content[0]
|
||||
trans = ""
|
||||
if key in key_dct and key not in repeat_lst:
|
||||
repeat_lst.append(key)
|
||||
fout.writelines(key + " " + trans + "\n")
|
||||
@ -0,0 +1,18 @@
|
||||
import sys
|
||||
|
||||
input_file = sys.argv[1]
|
||||
output_file = sys.argv[2]
|
||||
|
||||
with open(input_file) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
with open(output_file, "w") as wf:
|
||||
for line in lines:
|
||||
parts = line.strip().split(maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
key, text = parts
|
||||
else:
|
||||
key, parts = parts[0], " "
|
||||
text = [t for t in text.replace(" ", "")]
|
||||
text = " ".join(text)
|
||||
wf.write(key + " " + text + "\n")
|
||||
@ -0,0 +1,24 @@
|
||||
import sys
|
||||
from whisper_normalizer.basic import BasicTextNormalizer
|
||||
|
||||
basic_normalizer = BasicTextNormalizer()
|
||||
|
||||
|
||||
def normalize_text(srcfn, dstfn):
|
||||
with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write:
|
||||
all_lines = f_read.readlines()
|
||||
for line in all_lines:
|
||||
line = line.strip()
|
||||
line_arr = line.split()
|
||||
if len(line_arr) < 2:
|
||||
continue
|
||||
key = line_arr[0]
|
||||
conts = " ".join(line_arr[1:])
|
||||
normalized_conts = basic_normalizer(conts)
|
||||
f_write.write("{0}\t{1}\n".format(key, normalized_conts))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
srcfn = sys.argv[1]
|
||||
dstfn = sys.argv[2]
|
||||
normalize_text(srcfn, dstfn)
|
||||
@ -0,0 +1,22 @@
|
||||
import sys
|
||||
from whisper_normalizer.english import EnglishTextNormalizer
|
||||
|
||||
english_normalizer = EnglishTextNormalizer()
|
||||
|
||||
|
||||
def normalize_text(srcfn, dstfn):
|
||||
with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write:
|
||||
all_lines = f_read.readlines()
|
||||
for line in all_lines:
|
||||
line = line.strip()
|
||||
line_arr = line.split()
|
||||
key = line_arr[0]
|
||||
conts = " ".join(line_arr[1:])
|
||||
normalized_conts = english_normalizer(conts)
|
||||
f_write.write("{0}\t{1}\n".format(key, normalized_conts))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
srcfn = sys.argv[1]
|
||||
dstfn = sys.argv[2]
|
||||
normalize_text(srcfn, dstfn)
|
||||
@ -0,0 +1,32 @@
|
||||
import zhconv
|
||||
import argparse
|
||||
import codecs
|
||||
|
||||
|
||||
def convert_hant2cn(input_file, output_file):
|
||||
fout = codecs.open(output_file, "w")
|
||||
with codecs.open(input_file, "r") as fin:
|
||||
for line in fin:
|
||||
if "\t" in line:
|
||||
content = line.strip().split("\t", 1)
|
||||
else:
|
||||
content = line.strip().split(" ", 1)
|
||||
if len(content) == 2:
|
||||
idx, res = content[0], content[1]
|
||||
else:
|
||||
idx = content[0]
|
||||
res = ""
|
||||
convert_res = zhconv.convert(res, "zh-cn")
|
||||
# print(idx, res, convert_res)
|
||||
fout.writelines(idx + "\t" + convert_res + "\n")
|
||||
|
||||
fout.close()
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="manual to this script")
|
||||
parser.add_argument("--input_file", type=str, default=None)
|
||||
parser.add_argument("--output_file", type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_hant2cn(args.input_file, args.output_file)
|
||||
@ -0,0 +1,25 @@
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
from whisper_normalizer.english import EnglishTextNormalizer
|
||||
|
||||
english_normalizer = EnglishTextNormalizer()
|
||||
|
||||
|
||||
def normalize_text(srcfn, dstfn):
|
||||
with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write:
|
||||
all_lines = f_read.readlines()
|
||||
for line in all_lines:
|
||||
line = line.strip()
|
||||
line_arr = line.split()
|
||||
key = line_arr[0]
|
||||
conts = " ".join(line_arr[1:])
|
||||
normalized_conts = english_normalizer(conts)
|
||||
f_write.write("{0}\t{1}\n".format(key, normalized_conts))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
srcfn = sys.argv[1]
|
||||
dstfn = sys.argv[2]
|
||||
normalize_text(srcfn, dstfn)
|
||||
@ -0,0 +1,28 @@
|
||||
import zhconv
|
||||
import argparse
|
||||
import codecs
|
||||
|
||||
|
||||
def convert_hant2cn(input_file, output_file):
|
||||
fout = codecs.open(output_file, "w")
|
||||
with codecs.open(input_file, "r") as fin:
|
||||
for line in fin:
|
||||
content = line.strip().split(" ", 1)
|
||||
if len(content) == 2:
|
||||
idx, res = content[0], content[1]
|
||||
else:
|
||||
idx = content[0]
|
||||
res = ""
|
||||
convert_res = zhconv.convert(res, "zh-cn")
|
||||
fout.writelines(idx + "\t" + convert_res + "\n")
|
||||
|
||||
fout.close()
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="manual to this script")
|
||||
parser.add_argument("--input_file", type=str, default=None)
|
||||
parser.add_argument("--output_file", type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_hant2cn(args.input_file, args.output_file)
|
||||
@ -9,6 +9,7 @@ model = AutoModel(
|
||||
model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
|
||||
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
# vad_kwargs={"max_single_segment_time": 30000},
|
||||
device="cuda:0",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ input_file = (
|
||||
|
||||
model = AutoModel(
|
||||
model=model_dir,
|
||||
device="cuda:0",
|
||||
)
|
||||
|
||||
res = model.generate(
|
||||
|
||||
@ -9,6 +9,7 @@ model = AutoModel(
|
||||
model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscopeFSMN",
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
vad_kwargs={"max_single_segment_time": 30000},
|
||||
device="cuda:0",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -860,7 +860,7 @@ class LLMASR3(LLMASR2):
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
@tables.register("model_classes", "LLMASR4")
|
||||
# @tables.register("model_classes", "LLMASR4")
|
||||
class LLMASR4(nn.Module):
|
||||
""" """
|
||||
|
||||
@ -1339,7 +1339,7 @@ class LLMASR4(nn.Module):
|
||||
|
||||
# audio encoder
|
||||
speech = batch["speech"]
|
||||
|
||||
|
||||
if len(speech) > 0:
|
||||
if "audio_embedding" in kwargs and "audio_embedding_lens" in kwargs:
|
||||
encoder_out = kwargs["audio_embedding"]
|
||||
@ -2303,6 +2303,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
vocoder_conf = kwargs.get("vocoder_conf", None)
|
||||
self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf).to(torch.float32)
|
||||
import lameenc
|
||||
|
||||
self.mp3_encoder = lameenc.Encoder()
|
||||
self.mp3_encoder.set_bit_rate(128)
|
||||
self.mp3_encoder.set_in_sample_rate(22050)
|
||||
@ -3023,7 +3024,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
@torch.no_grad()
|
||||
def generate_speech_one_step(
|
||||
self,
|
||||
text: str, preds: str,
|
||||
text: str,
|
||||
preds: str,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
@ -3051,14 +3053,14 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
preds = self.split_characters_and_words(normed_preds[:str_idx])
|
||||
idx = len(preds)
|
||||
preds.append(normed_preds[str_idx])
|
||||
preds.extend(self.split_characters_and_words(normed_preds[str_idx+1:]))
|
||||
preds.extend(self.split_characters_and_words(normed_preds[str_idx + 1 :]))
|
||||
break
|
||||
|
||||
_text = f"<|endofprompt|><|sil|>{text+normed_preds}" + ("<|sil|>" if is_last else "")
|
||||
para_end = False
|
||||
if idx > -1 and not is_last:
|
||||
pre_part = "".join(preds[:idx+1])
|
||||
if len(self.tts_tokenizer_warpper(text+pre_part)) >= para_phone_len:
|
||||
pre_part = "".join(preds[: idx + 1])
|
||||
if len(self.tts_tokenizer_warpper(text + pre_part)) >= para_phone_len:
|
||||
_text = f"<|endofprompt|><|sil|>{text+pre_part}<|sil|>"
|
||||
para_end = True
|
||||
|
||||
@ -3109,7 +3111,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
last_t_size = t_size
|
||||
|
||||
if para_end:
|
||||
text = "".join(preds[idx + 1:])
|
||||
text = "".join(preds[idx + 1 :])
|
||||
last_t_size = 0
|
||||
prompt_token, prompt_audio = [None, None], [None, None]
|
||||
wav = torch.cat([wav, torch.zeros([1, 2205]).to(wav)], dim=1)
|
||||
@ -3121,17 +3123,18 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def simple_generate_speech_one_step(
|
||||
self,
|
||||
text: str, preds: str,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
prompt_token,
|
||||
prompt_audio,
|
||||
tts_text_chunk_size,
|
||||
chunk_idx,
|
||||
is_last,
|
||||
para_phone_len=200,
|
||||
self,
|
||||
text: str,
|
||||
preds: str,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
prompt_token,
|
||||
prompt_audio,
|
||||
tts_text_chunk_size,
|
||||
chunk_idx,
|
||||
is_last,
|
||||
para_phone_len=200,
|
||||
):
|
||||
device = llm_cur_kv_cache.device
|
||||
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
|
||||
@ -3194,6 +3197,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
if is_last:
|
||||
mp3_data += self.mp3_encoder.flush()
|
||||
import lameenc
|
||||
|
||||
self.mp3_encoder = lameenc.Encoder()
|
||||
self.mp3_encoder.set_bit_rate(128)
|
||||
self.mp3_encoder.set_in_sample_rate(22050)
|
||||
@ -3227,7 +3231,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
|
||||
# new_text = new_text + _resp
|
||||
rt_value, states = self.generate_speech_one_step(
|
||||
new_text, _resp,
|
||||
new_text,
|
||||
_resp,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
@ -3288,7 +3293,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
# new_text = new_text + preds
|
||||
with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
|
||||
rt_value, states_ret = self.generate_speech_one_step(
|
||||
new_text, preds,
|
||||
new_text,
|
||||
preds,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
@ -3311,14 +3317,14 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
return cur_token, feat, wav
|
||||
|
||||
def simple_streaming_generate_speech(
|
||||
self,
|
||||
preds,
|
||||
states,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
is_last=False,
|
||||
text_chunk_size=8,
|
||||
format="mp3",
|
||||
self,
|
||||
preds,
|
||||
states,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
is_last=False,
|
||||
text_chunk_size=8,
|
||||
format="mp3",
|
||||
):
|
||||
|
||||
new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = (
|
||||
@ -3331,7 +3337,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
# new_text = new_text + preds
|
||||
with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
|
||||
rt_value, states_ret = self.simple_generate_speech_one_step(
|
||||
preds, "",
|
||||
preds,
|
||||
"",
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
|
||||
@ -691,3 +691,637 @@ class LayerNorm(nn.LayerNorm):
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
@tables.register("model_classes", "LLMASR4")
|
||||
@tables.register("model_classes", "MinMo_S2T")
|
||||
class MinMo_S2T(nn.Module):
|
||||
""" """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_encoder: str = None,
|
||||
audio_encoder_conf: dict = None,
|
||||
audio_adaptor: str = None,
|
||||
audio_adaptor_conf: dict = None,
|
||||
llm: str = None,
|
||||
llm_conf: dict = None,
|
||||
input_size: int = 80,
|
||||
length_normalized_loss: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
# audio encoder
|
||||
hub = audio_encoder_conf.get("hub", None)
|
||||
self.audio_encoder_activation_checkpoint = audio_encoder_conf.get(
|
||||
"activation_checkpoint", False
|
||||
)
|
||||
if hub == "ms":
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model=audio_encoder, model_revision="master")
|
||||
audio_encoder_output_size = (
|
||||
model.model.encoder_output_size
|
||||
if hasattr(model.model, "encoder_output_size")
|
||||
else -1
|
||||
)
|
||||
|
||||
audio_encoder = (
|
||||
model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
|
||||
)
|
||||
|
||||
# self.frontend = frontend
|
||||
|
||||
elif hub == "hf":
|
||||
pass
|
||||
else:
|
||||
encoder_class = tables.encoder_classes.get(audio_encoder)
|
||||
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
|
||||
audio_encoder_output_size = audio_encoder.output_size()
|
||||
freeze = audio_encoder_conf.get("freeze", True)
|
||||
freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
|
||||
# if freeze_layer_num > 0:
|
||||
# freeze_layer_num = range(freeze_layer_num)
|
||||
|
||||
if freeze:
|
||||
for name, param in audio_encoder.named_parameters():
|
||||
if freeze_layer_num > 0:
|
||||
idx = re.search(r"\.\d+\.", name)
|
||||
if idx is not None:
|
||||
beg, end = idx.regs[0]
|
||||
layer_id = int(name[beg + 1 : end - 1])
|
||||
if layer_id < freeze_layer_num:
|
||||
param.requires_grad = False
|
||||
elif "ln_post." not in name:
|
||||
param.requires_grad = False
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
audio_encoder.eval()
|
||||
|
||||
self.audio_encoder = audio_encoder
|
||||
|
||||
# llm
|
||||
self.llm = None
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
|
||||
|
||||
init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
|
||||
llm_load_kwargs = llm_conf.get("load_kwargs", {})
|
||||
|
||||
if not llm_conf.get("low_cpu", False):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
init_param_path,
|
||||
load_in_8bit=None,
|
||||
device_map=None,
|
||||
use_cache=None,
|
||||
**llm_load_kwargs,
|
||||
)
|
||||
else:
|
||||
import os
|
||||
|
||||
if int(os.environ.get("RANK", 0)) == 0:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
init_param_path,
|
||||
load_in_8bit=None,
|
||||
device_map="cpu",
|
||||
use_cache=None,
|
||||
**llm_load_kwargs,
|
||||
)
|
||||
else:
|
||||
llm_config = AutoConfig.from_pretrained(init_param_path)
|
||||
model = AutoModelForCausalLM.from_config(llm_config)
|
||||
|
||||
freeze = llm_conf.get("freeze", True)
|
||||
if freeze:
|
||||
for name, param in model.named_parameters():
|
||||
param.requires_grad = False
|
||||
model.eval()
|
||||
|
||||
logging.info(f"use_lora: {llm_conf.get('use_lora', False)}")
|
||||
if llm_conf.get("use_lora", False):
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
|
||||
lora_conf = llm_conf.get("lora_conf", {})
|
||||
if isinstance(lora_conf, (OmegaConf, DictConfig)):
|
||||
lora_conf = OmegaConf.to_container(lora_conf, resolve=True)
|
||||
from peft import get_peft_model, LoraConfig, TaskType, PeftConfig, PeftModel
|
||||
|
||||
lora_init_param_path = lora_conf.get("init_param_path", None)
|
||||
if lora_init_param_path is not None:
|
||||
logging.info(f"lora_init_param_path: {lora_init_param_path}")
|
||||
model = PeftModel.from_pretrained(model, lora_init_param_path)
|
||||
for name, param in model.named_parameters():
|
||||
if not lora_conf.get("freeze_lora", False):
|
||||
if "lora_" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
peft_config = LoraConfig(**lora_conf)
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
if llm_conf.get("activation_checkpoint", False):
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
|
||||
self.llm = model.to(dtype_map[self.llm_dtype])
|
||||
llm_dim = model.get_input_embeddings().weight.shape[-1]
|
||||
|
||||
# adaptor
|
||||
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
|
||||
if audio_encoder_output_size > 0:
|
||||
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
|
||||
audio_adaptor_conf["llm_dim"] = llm_dim
|
||||
audio_adaptor = adaptor_class(**audio_adaptor_conf)
|
||||
init_param_path = audio_adaptor_conf.get("init_param_path", None)
|
||||
if init_param_path is not None:
|
||||
src_state = torch.load(init_param_path, map_location="cpu")
|
||||
flag = audio_adaptor.load_state_dict(src_state, strict=False)
|
||||
logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")
|
||||
freeze = audio_adaptor_conf.get("freeze", False)
|
||||
if freeze:
|
||||
for name, param in audio_adaptor.named_parameters():
|
||||
param.requires_grad = False
|
||||
audio_adaptor.eval()
|
||||
|
||||
self.audio_adaptor = audio_adaptor
|
||||
|
||||
self.error_calculator = None
|
||||
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
self.beam_search = None
|
||||
import os
|
||||
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
logging.info(f"rank: {rank}, model is builded.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor = None,
|
||||
speech_lengths: torch.Tensor = None,
|
||||
input_ids: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
labels_ids: torch.Tensor = None,
|
||||
fbank_beg: torch.Tensor = None,
|
||||
fbank_mask: torch.Tensor = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
# import pdb
|
||||
#
|
||||
# pdb.set_trace()
|
||||
batch_size, token_num = input_ids.shape
|
||||
stats = {}
|
||||
input_ids[input_ids < 0] = 0
|
||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||
if speech is not None:
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size_speech, frames, _ = speech.shape
|
||||
|
||||
# audio encoder
|
||||
if self.audio_encoder_activation_checkpoint:
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
encoder_out, encoder_out_lens = checkpoint(
|
||||
self.encode, speech, speech_lengths, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
# audio_adaptor
|
||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
fake_token_len = kwargs.get("fake_token_len")
|
||||
fake_token_len[fake_token_len < 0] = 0
|
||||
fbank_beg[fbank_beg < 0] = 0
|
||||
|
||||
speech_idx = 0
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
for turn_id in range(fbank_beg.shape[1]):
|
||||
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
|
||||
if fbank_beg_idx > 0:
|
||||
speech_token_len = fake_token_len[batch_idx, turn_id]
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
#
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
|
||||
)
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
speech_idx += 1
|
||||
|
||||
stats["batch_size_speech"] = batch_size_speech
|
||||
stats["batch_size_x_frames"] = frames * batch_size_speech
|
||||
stats["batch_size_real_frames"] = speech_lengths.sum().item()
|
||||
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
|
||||
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
|
||||
):
|
||||
labels_ids[labels_ids == -1] = -100
|
||||
attention_mask[attention_mask < 0] = 0
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]),
|
||||
attention_mask=attention_mask,
|
||||
labels=labels_ids,
|
||||
)
|
||||
loss = model_outputs.loss
|
||||
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
|
||||
stats["acc"] = acc_att
|
||||
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
stats["batch_size"] = batch_size
|
||||
|
||||
stats["batch_size_x_tokens"] = token_num * batch_size
|
||||
stats["batch_size_real_tokens"] = attention_mask.sum().item()
|
||||
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
|
||||
|
||||
dialog_turns = (fbank_beg > 0).sum(-1)
|
||||
dialog_turns_max = torch.max(dialog_turns).int().item()
|
||||
dialog_turns_avg = dialog_turns.sum().item() / batch_size
|
||||
stats["dialog_turns_max"] = dialog_turns_max
|
||||
stats["dialog_turns_avg"] = dialog_turns_avg
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
if self.length_normalized_loss:
|
||||
batch_size = int((labels_ids > 0 + 1).sum())
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def encode(self, speech, speech_lengths):
|
||||
# audio encoder
|
||||
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def data_template(self, data):
|
||||
system, user, assistant = [], [], []
|
||||
for i, item in enumerate(data):
|
||||
role = item["role"]
|
||||
content = item["content"]
|
||||
if role == "system":
|
||||
system.append(content)
|
||||
elif role == "user":
|
||||
if "audio" in item:
|
||||
audio = item["audio"]
|
||||
content = [content, audio]
|
||||
user.append(content)
|
||||
elif role == "assistant":
|
||||
assistant.append(content)
|
||||
|
||||
system = system * len(user)
|
||||
|
||||
contents = {
|
||||
"system": system,
|
||||
"user": user,
|
||||
"assistant": assistant,
|
||||
}
|
||||
|
||||
return contents
|
||||
|
||||
def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
|
||||
|
||||
system = contents["system"]
|
||||
user = contents["user"]
|
||||
assistant = contents["assistant"]
|
||||
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
|
||||
|
||||
input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
input_source_ids = []
|
||||
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
|
||||
if i >= kwargs.get("multiturn_num_max", 5):
|
||||
break
|
||||
if len(input_ids) > kwargs.get("max_token_length", 1500):
|
||||
break
|
||||
if isinstance(user_prompt, (list, tuple)):
|
||||
user_prompt, audio = user_prompt
|
||||
if i == 0:
|
||||
if kwargs.get("infer_with_assistant_input", False):
|
||||
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}"
|
||||
else:
|
||||
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
else:
|
||||
if kwargs.get("infer_with_assistant_input", False):
|
||||
source_input = f"<|im_start|>user\n{user_prompt}"
|
||||
else:
|
||||
source_input = (
|
||||
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
splits = pattern.split(source_input)
|
||||
source_ids = []
|
||||
fbank_i = []
|
||||
fbank_mask_i = []
|
||||
fake_token_len_i = 0
|
||||
fbank_beg_i = -1
|
||||
fbank_lens_i = []
|
||||
speech, speech_lengths = [], []
|
||||
for k, sub_str in enumerate(splits):
|
||||
if not sub_str.startswith("<|startofspeech|>"):
|
||||
sub_token = tokenizer.encode(sub_str)
|
||||
source_ids += sub_token
|
||||
fbank_mask_i += [0] * len(sub_token)
|
||||
else:
|
||||
sub_str = sub_str.replace("<|startofspeech|>", "").replace(
|
||||
"<|endofspeech|>", ""
|
||||
)
|
||||
if sub_str.startswith("!"):
|
||||
sub_str = sub_str[1:]
|
||||
if sub_str.startswith("!"): # !!: audio sample point
|
||||
sub_str = audio
|
||||
try:
|
||||
time1 = time.perf_counter()
|
||||
data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
|
||||
time2 = time.perf_counter()
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
except Exception as e:
|
||||
logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
|
||||
|
||||
speech, speech_lengths = extract_fbank(
|
||||
data_src,
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend,
|
||||
is_final=True,
|
||||
) # speech: [b, T, d]
|
||||
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
meta_data["batch_data_time"] = (
|
||||
speech_lengths.sum().item()
|
||||
* frontend.frame_shift
|
||||
* frontend.lfr_n
|
||||
/ 1000
|
||||
)
|
||||
|
||||
if kwargs.get("permute", True):
|
||||
speech = speech.permute(0, 2, 1)
|
||||
if speech_lengths > kwargs.get("max_source_length", 5500):
|
||||
# logging.info(
|
||||
# f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
|
||||
# )
|
||||
badcase_flag = True
|
||||
|
||||
olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
|
||||
olens = 1 + (olens - 3 + 2 * 1) // 2
|
||||
fake_token_len_i = (olens - 1) // 2 + 1
|
||||
fake_token = [0] * fake_token_len_i
|
||||
fbank_beg_i = len(source_ids)
|
||||
source_ids += fake_token
|
||||
fbank_mask_i += [1] * len(fake_token)
|
||||
|
||||
fbank_beg += [fbank_beg_i + len(input_ids)]
|
||||
fake_token_len += [fake_token_len_i]
|
||||
source_mask = [-100] * len(source_ids)
|
||||
target_out = f"{target_out}<|im_end|>"
|
||||
target_ids = tokenizer.encode(target_out)
|
||||
input_source_ids = input_ids + source_ids
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
fbank_mask += fbank_mask_i
|
||||
if len(speech) > 0:
|
||||
fbank.append(speech[0, :, :])
|
||||
fbank_lens.append(speech_lengths)
|
||||
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
|
||||
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
|
||||
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
|
||||
|
||||
# fbank = speech[0, :, :]
|
||||
# fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
|
||||
fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
|
||||
fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
|
||||
fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
|
||||
source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
|
||||
target_ids = torch.tensor(target_ids, dtype=torch.int64)
|
||||
|
||||
if len(fbank) > 0:
|
||||
speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
|
||||
speech_lengths = torch.nn.utils.rnn.pad_sequence(
|
||||
fbank_lens, batch_first=True, padding_value=-1
|
||||
)
|
||||
else:
|
||||
speech = []
|
||||
speech_lengths = []
|
||||
output = {
|
||||
"speech": speech,
|
||||
"speech_lengths": speech_lengths,
|
||||
"fbank_mask": fbank_mask[None, :],
|
||||
"fbank_beg": fbank_beg[None,],
|
||||
"fake_token_len": fake_token_len[None, :],
|
||||
"input_ids": input_ids[None,],
|
||||
"attention_mask": attention_mask[None,],
|
||||
"labels_ids": labels,
|
||||
"source_ids": source_ids[None, :],
|
||||
"target_ids": target_ids[None, :],
|
||||
}
|
||||
|
||||
return output
|
||||
|
||||
def inference_prepare(
|
||||
self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
meta_data = {}
|
||||
prompt = kwargs.get("prompt", None)
|
||||
|
||||
if kwargs.get("batch_size", 1) > 1:
|
||||
raise NotImplementedError("batch decoding is not implemented")
|
||||
|
||||
contents = self.data_template(data_in[0])
|
||||
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
|
||||
batch = to_device(output, kwargs["device"])
|
||||
|
||||
# audio encoder
|
||||
speech = batch["speech"]
|
||||
|
||||
if len(speech) > 0:
|
||||
if "audio_embedding" in kwargs and "audio_embedding_lens" in kwargs:
|
||||
encoder_out = kwargs["audio_embedding"]
|
||||
encoder_out_lens = kwargs["audio_embedding_lens"]
|
||||
else:
|
||||
speech_lengths = batch["speech_lengths"][:, 0]
|
||||
# fp16
|
||||
if kwargs.get("fp16", False):
|
||||
speech = speech.to(torch.float16)
|
||||
elif kwargs.get("bf16", False):
|
||||
speech = speech.to(torch.bfloat16)
|
||||
# audio encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
# audio_adaptor
|
||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||
|
||||
meta_data["audio_adaptor_out"] = encoder_out
|
||||
meta_data["audio_adaptor_out_lens"] = encoder_out_lens
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
source_ids = batch["source_ids"]
|
||||
fbank_beg = batch["fbank_beg"]
|
||||
fake_token_len = batch["fake_token_len"]
|
||||
|
||||
if not kwargs.get("tearchforing", False):
|
||||
input_ids = source_ids
|
||||
|
||||
input_ids[input_ids < 0] = 0
|
||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
|
||||
fake_token_len[fake_token_len < 0] = 0
|
||||
fbank_beg[fbank_beg < 0] = 0
|
||||
|
||||
speech_idx = 0
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
for turn_id in range(fbank_beg.shape[1]):
|
||||
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
|
||||
if fbank_beg_idx > 0:
|
||||
speech_token_len = fake_token_len[batch_idx, turn_id]
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
#
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
|
||||
)
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
speech_idx += 1
|
||||
return inputs_embeds, contents, batch, source_ids, meta_data
|
||||
|
||||
def inference(
|
||||
self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
|
||||
data_in, data_lengths, key, tokenizer, frontend, **kwargs
|
||||
)
|
||||
|
||||
llm_dtype = kwargs.get("llm_dtype", "fp32")
|
||||
if llm_dtype == "fp32":
|
||||
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
|
||||
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
|
||||
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
|
||||
):
|
||||
label = contents["assistant"][-1]
|
||||
self.llm = self.llm.to(dtype_map[llm_dtype])
|
||||
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
||||
llm_kwargs = kwargs.get("llm_kwargs", {})
|
||||
if not kwargs.get("tearchforing", False):
|
||||
|
||||
generated_ids = self.llm.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
max_new_tokens=kwargs.get("max_length", 512),
|
||||
**llm_kwargs,
|
||||
)
|
||||
# generated_ids = [
|
||||
# output_ids[len(input_id) :]
|
||||
# for input_id, output_ids in zip(input_ids, generated_ids)
|
||||
# ]
|
||||
response = tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
|
||||
)[0]
|
||||
|
||||
loss = None
|
||||
else:
|
||||
|
||||
labels_ids = batch["labels_ids"]
|
||||
labels_ids[labels_ids == -1] = -100
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
# attention_mask = attention_mask.to(dtype_map[llm_dtype])
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels_ids,
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
|
||||
response = tokenizer.batch_decode(
|
||||
preds,
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=kwargs.get("skip_special_tokens", True),
|
||||
)[0]
|
||||
loss = model_outputs.loss.item()
|
||||
|
||||
ibest_writer = None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
if not hasattr(self, "writer"):
|
||||
self.writer = DatadirWriter(kwargs.get("output_dir"))
|
||||
ibest_writer = self.writer[f"{0 + 1}best_recog"]
|
||||
|
||||
results = []
|
||||
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
|
||||
result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label}
|
||||
if loss is not None:
|
||||
result_i["loss"] = loss
|
||||
results.append(result_i)
|
||||
|
||||
if ibest_writer is not None:
|
||||
ibest_writer["text"][key[0]] = response.replace("\n", " ")
|
||||
ibest_writer["label"][key[0]] = label.replace("\n", " ")
|
||||
ibest_writer["text_tn"][key[0]] = response_clean
|
||||
|
||||
return results, meta_data
|
||||
|
||||
Loading…
Reference in New Issue
Block a user