mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'alibaba-damo-academy:main' into main
This commit is contained in:
commit
a1b63ca9cc
@ -52,7 +52,7 @@ asr_config=conf/train_asr_conformer.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
|
||||
@ -55,7 +55,7 @@ asr_config=conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
|
||||
@ -55,7 +55,7 @@ asr_config=conf/train_asr_transformer_12e_6d_3072_768.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer.yaml
|
||||
inference_asr_model=valid.cer_ctc.ave_10best.pth
|
||||
inference_asr_model=valid.cer_ctc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
|
||||
@ -52,7 +52,7 @@ asr_config=conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
|
||||
@ -56,7 +56,7 @@ asr_config=conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
|
||||
@ -52,7 +52,7 @@ asr_config=conf/train_asr_conformer.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
|
||||
@ -54,7 +54,7 @@ asr_config=conf/train_asr_conformer.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
|
||||
|
||||
@ -54,7 +54,7 @@ asr_config=conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
|
||||
|
||||
@ -58,7 +58,7 @@ asr_config=conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
|
||||
|
||||
@ -54,7 +54,7 @@ asr_config=conf/train_asr_transformer.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
|
||||
|
||||
@ -34,7 +34,7 @@ exp_dir=./data
|
||||
tag=exp1
|
||||
model_dir="baseline_$(basename "${lm_config}" .yaml)_${lang}_${token_type}_${tag}"
|
||||
lm_exp=${exp_dir}/exp/${model_dir}
|
||||
inference_lm=valid.loss.ave.pth # Language model path for decoding.
|
||||
inference_lm=valid.loss.ave.pb # Language model path for decoding.
|
||||
|
||||
stage=0
|
||||
stop_stage=3
|
||||
|
||||
@ -4,7 +4,7 @@ import sys
|
||||
|
||||
def main():
|
||||
diar_config_path = sys.argv[1] if len(sys.argv) > 1 else "sond_fbank.yaml"
|
||||
diar_model_path = sys.argv[2] if len(sys.argv) > 2 else "sond.pth"
|
||||
diar_model_path = sys.argv[2] if len(sys.argv) > 2 else "sond.pb"
|
||||
output_dir = sys.argv[3] if len(sys.argv) > 3 else "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/test_rmsil/feats.scp", "speech", "kaldi_ark"),
|
||||
|
||||
@ -17,9 +17,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "Downloading Pre-trained model..."
|
||||
git clone https://www.modelscope.cn/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch.git
|
||||
git clone https://www.modelscope.cn/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch.git
|
||||
ln -s speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth ./sv.pth
|
||||
ln -s speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pb ./sv.pb
|
||||
cp speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.yaml ./sv.yaml
|
||||
ln -s speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pth ./sond.pth
|
||||
ln -s speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pb ./sond.pb
|
||||
cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond_fbank.yaml ./sond_fbank.yaml
|
||||
cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.yaml ./sond.yaml
|
||||
echo "Done."
|
||||
@ -30,7 +30,7 @@ fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "Calculating diarization results..."
|
||||
python infer_alimeeting_test.py sond_fbank.yaml sond.pth outputs
|
||||
python infer_alimeeting_test.py sond_fbank.yaml sond.pb outputs
|
||||
python local/convert_label_to_rttm.py \
|
||||
outputs/labels.txt \
|
||||
data/test_rmsil/raw_rmsil_map.scp \
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
|
||||
def test_fbank_cpu_infer():
|
||||
diar_config_path = "config_fbank.yaml"
|
||||
diar_model_path = "sond.pth"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
|
||||
@ -24,7 +24,7 @@ def test_fbank_cpu_infer():
|
||||
|
||||
def test_fbank_gpu_infer():
|
||||
diar_config_path = "config_fbank.yaml"
|
||||
diar_model_path = "sond.pth"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
|
||||
@ -45,7 +45,7 @@ def test_fbank_gpu_infer():
|
||||
|
||||
def test_wav_gpu_infer():
|
||||
diar_config_path = "config.yaml"
|
||||
diar_model_path = "sond.pth"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_wav.scp", "speech", "sound"),
|
||||
@ -66,7 +66,7 @@ def test_wav_gpu_infer():
|
||||
|
||||
def test_without_profile_gpu_infer():
|
||||
diar_config_path = "config.yaml"
|
||||
diar_model_path = "sond.pth"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
raw_inputs = [[
|
||||
"data/unit_test/raw_inputs/record.wav",
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
|
||||
def test_fbank_cpu_infer():
|
||||
diar_config_path = "sond_fbank.yaml"
|
||||
diar_model_path = "sond.pth"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
|
||||
@ -24,7 +24,7 @@ def test_fbank_cpu_infer():
|
||||
|
||||
def test_fbank_gpu_infer():
|
||||
diar_config_path = "sond_fbank.yaml"
|
||||
diar_model_path = "sond.pth"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
|
||||
@ -45,7 +45,7 @@ def test_fbank_gpu_infer():
|
||||
|
||||
def test_wav_gpu_infer():
|
||||
diar_config_path = "config.yaml"
|
||||
diar_model_path = "sond.pth"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
("data/unit_test/test_wav.scp", "speech", "sound"),
|
||||
@ -66,7 +66,7 @@ def test_wav_gpu_infer():
|
||||
|
||||
def test_without_profile_gpu_infer():
|
||||
diar_config_path = "config.yaml"
|
||||
diar_model_path = "sond.pth"
|
||||
diar_model_path = "sond.pb"
|
||||
output_dir = "./outputs"
|
||||
raw_inputs = [[
|
||||
"data/unit_test/raw_inputs/record.wav",
|
||||
|
||||
@ -49,7 +49,7 @@ asr_config=conf/train_asr_conformer.yaml
|
||||
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
|
||||
|
||||
inference_config=conf/decode_asr_transformer.yaml
|
||||
inference_asr_model=valid.acc.ave_10best.pth
|
||||
inference_asr_model=valid.acc.ave_10best.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
|
||||
@ -41,7 +41,7 @@ The decoding results can be found in `$output_dir/1best_recog/text.cer`, which i
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -48,5 +48,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "valid.cer_ctc.ave.pth"
|
||||
params["decoding_model_name"] = "valid.cer_ctc.ave.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -41,7 +41,7 @@ The decoding results can be found in `$output_dir/1best_recog/text.cer`, which i
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -48,5 +48,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "valid.cer_ctc.ave.pth"
|
||||
params["decoding_model_name"] = "valid.cer_ctc.ave.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -41,7 +41,7 @@ The decoding results can be found in `$output_dir/1best_recog/text.sp.cer` and `
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -63,5 +63,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./example_data/validation"
|
||||
params["decoding_model_name"] = "valid.acc.ave.pth"
|
||||
params["decoding_model_name"] = "valid.acc.ave.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -49,5 +49,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pth"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -49,5 +49,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pth"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -41,7 +41,7 @@ The decoding results can be found in `$output_dir/1best_recog/text.cer`, which i
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -49,5 +49,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pth"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -49,5 +49,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pth"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -41,7 +41,7 @@ The decoding results can be found in `$output_dir/1best_recog/text.cer`, which i
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -50,5 +50,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "20epoch.pth"
|
||||
params["decoding_model_name"] = "20epoch.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -41,7 +41,7 @@ The decoding results can be found in `$output_dir/1best_recog/text.cer`, which i
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -50,5 +50,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "20epoch.pth"
|
||||
params["decoding_model_name"] = "20epoch.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -41,7 +41,7 @@ The decoding results can be found in `$output_dir/1best_recog/text.cer`, which i
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -50,5 +50,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "20epoch.pth"
|
||||
params["decoding_model_name"] = "20epoch.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -41,7 +41,7 @@ The decoding results can be found in `$output_dir/1best_recog/text.cer`, which i
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -49,5 +49,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "20epoch.pth"
|
||||
params["decoding_model_name"] = "20epoch.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -41,7 +41,8 @@ The decoding results can be found in `$output_dir/1best_recog/text.cer`, which i
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -49,5 +49,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "20epoch.pth"
|
||||
params["decoding_model_name"] = "20epoch.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -34,7 +34,7 @@ Or you can use the finetuned model for inference directly.
|
||||
- Modify inference related parameters in `infer_after_finetune.py`
|
||||
- <strong>output_dir:</strong> # result dir
|
||||
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```python
|
||||
|
||||
@ -53,5 +53,5 @@ if __name__ == '__main__':
|
||||
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json", "punc/punc.pb", "punc/punc.yaml", "vad/vad.mvn", "vad/vad.pb", "vad/vad.yaml"]
|
||||
params["output_dir"] = "./checkpoint"
|
||||
params["data_dir"] = "./data/test"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pth"
|
||||
params["decoding_model_name"] = "valid.acc.ave_10best.pb"
|
||||
modelscope_infer_after_finetune(params)
|
||||
|
||||
@ -52,7 +52,7 @@ class Speech2Text:
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
@ -55,7 +55,7 @@ class Speech2Text:
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
@ -50,7 +50,7 @@ class Speech2Text:
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
@ -58,7 +58,7 @@ class Speech2Text:
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
@ -49,7 +49,7 @@ class Speech2Text:
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
@ -46,7 +46,7 @@ class Speech2Text:
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
@ -46,7 +46,7 @@ class Speech2Text:
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
@ -133,7 +133,7 @@ def inference_launch(mode, **kwargs):
|
||||
param_dict = {
|
||||
"extract_profile": True,
|
||||
"sv_train_config": "sv.yaml",
|
||||
"sv_model_file": "sv.pth",
|
||||
"sv_model_file": "sv.pb",
|
||||
}
|
||||
if "param_dict" in kwargs and kwargs["param_dict"] is not None:
|
||||
for key in param_dict:
|
||||
|
||||
@ -35,7 +35,7 @@ class Speech2Diarization:
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> import numpy as np
|
||||
>>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
|
||||
>>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
|
||||
>>> profile = np.load("profiles.npy")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2diar(audio, profile)
|
||||
@ -209,7 +209,7 @@ def inference_modelscope(
|
||||
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||
if isinstance(raw_inputs, torch.Tensor):
|
||||
raw_inputs = raw_inputs.numpy()
|
||||
data_path_and_name_and_type = [raw_inputs[0], "speech", "bytes"]
|
||||
data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
|
||||
loader = EENDOLADiarTask.build_streaming_iterator(
|
||||
data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
|
||||
@ -42,7 +42,7 @@ class Speech2Diarization:
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> import numpy as np
|
||||
>>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
|
||||
>>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
|
||||
>>> profile = np.load("profiles.npy")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2diar(audio, profile)
|
||||
|
||||
@ -36,7 +36,7 @@ class Speech2Xvector:
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pth")
|
||||
>>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2xvector(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
@ -169,7 +169,7 @@ def inference_modelscope(
|
||||
log_level: Union[int, str] = "INFO",
|
||||
key_file: Optional[str] = None,
|
||||
sv_train_config: Optional[str] = "sv.yaml",
|
||||
sv_model_file: Optional[str] = "sv.pth",
|
||||
sv_model_file: Optional[str] = "sv.pb",
|
||||
model_tag: Optional[str] = None,
|
||||
allow_variable_data_keys: bool = True,
|
||||
streaming: bool = False,
|
||||
|
||||
@ -66,13 +66,13 @@ def average_nbest_models(
|
||||
elif n == 1:
|
||||
# The averaged model is same as the best model
|
||||
e, _ = epoch_and_values[0]
|
||||
op = output_dir / f"{e}epoch.pth"
|
||||
sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth"
|
||||
op = output_dir / f"{e}epoch.pb"
|
||||
sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
|
||||
if sym_op.is_symlink() or sym_op.exists():
|
||||
sym_op.unlink()
|
||||
sym_op.symlink_to(op.name)
|
||||
else:
|
||||
op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth"
|
||||
op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
|
||||
logging.info(
|
||||
f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
|
||||
)
|
||||
@ -83,12 +83,12 @@ def average_nbest_models(
|
||||
if e not in _loaded:
|
||||
if oss_bucket is None:
|
||||
_loaded[e] = torch.load(
|
||||
output_dir / f"{e}epoch.pth",
|
||||
output_dir / f"{e}epoch.pb",
|
||||
map_location="cpu",
|
||||
)
|
||||
else:
|
||||
buffer = BytesIO(
|
||||
oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pth")).read())
|
||||
oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
|
||||
_loaded[e] = torch.load(buffer)
|
||||
states = _loaded[e]
|
||||
|
||||
@ -115,13 +115,13 @@ def average_nbest_models(
|
||||
else:
|
||||
buffer = BytesIO()
|
||||
torch.save(avg, buffer)
|
||||
oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pth"),
|
||||
oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
|
||||
buffer.getvalue())
|
||||
|
||||
# 3. *.*.ave.pth is a symlink to the max ave model
|
||||
# 3. *.*.ave.pb is a symlink to the max ave model
|
||||
if oss_bucket is None:
|
||||
op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth"
|
||||
sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth"
|
||||
op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
|
||||
sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
|
||||
if sym_op.is_symlink() or sym_op.exists():
|
||||
sym_op.unlink()
|
||||
sym_op.symlink_to(op.name)
|
||||
|
||||
@ -191,12 +191,12 @@ def unpack(
|
||||
|
||||
Examples:
|
||||
tarfile:
|
||||
model.pth
|
||||
model.pb
|
||||
some1.file
|
||||
some2.file
|
||||
|
||||
>>> unpack("tarfile", "out")
|
||||
{'asr_model_file': 'out/model.pth'}
|
||||
{'asr_model_file': 'out/model.pb'}
|
||||
"""
|
||||
input_archive = Path(input_archive)
|
||||
outpath = Path(outpath)
|
||||
|
||||
@ -87,7 +87,7 @@ class EENDOLATransformerEncoder(nn.Module):
|
||||
n_layers: int,
|
||||
n_units: int,
|
||||
e_units: int = 2048,
|
||||
h: int = 8,
|
||||
h: int = 4,
|
||||
dropout_rate: float = 0.1,
|
||||
use_pos_emb: bool = False):
|
||||
super(EENDOLATransformerEncoder, self).__init__()
|
||||
|
||||
21
funasr/runtime/python/README.md
Normal file
21
funasr/runtime/python/README.md
Normal file
@ -0,0 +1,21 @@
|
||||
Benchmark of [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) on the Aishell1 test set; the total audio duration is 36108.919 seconds.
|
||||
|
||||
(Note: The service has been fully warmed up.)
|
||||
|
||||
Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz 16core-32processor with avx512_vnni
|
||||
|
||||
| concurrent-tasks | processing time(s) | RTF | Speedup Rate |
|
||||
|:----------------:|:------------------:|:------:|:------------:|
|
||||
| 1 (onnx fp32) | 2806 | 0.0777 | 12.9 |
|
||||
| 1 (onnx int8) | 1611 | 0.0446 | 22.4 |
|
||||
| 8 (onnx fp32) | 538 | 0.0149 | 67.1 |
|
||||
| 8 (onnx int8) | 210 | 0.0058 | 172.4 |
|
||||
| 16 (onnx fp32) | 288 | 0.0080 | 125.2 |
|
||||
| 16 (onnx int8) | 117 | 0.0032 | 309.9 |
|
||||
| 32 (onnx fp32) | 167 | 0.0046 | 216.5 |
|
||||
| 32 (onnx int8) | 107 | 0.0030 | 338.0 |
|
||||
| 64 (onnx fp32) | 158 | 0.0044 | 228.1 |
|
||||
| 64 (onnx int8) | 82 | 0.0023 | 442.8 |
|
||||
| 96 (onnx fp32) | 151 | 0.0042 | 238.0 |
|
||||
| 96 (onnx int8) | 80 | 0.0022 | 452.0 |
|
||||
|
||||
@ -639,12 +639,12 @@ class AbsTask(ABC):
|
||||
"and exclude_keys excludes keys of model states for the initialization."
|
||||
"e.g.\n"
|
||||
" # Load all parameters"
|
||||
" --init_param some/where/model.pth\n"
|
||||
" --init_param some/where/model.pb\n"
|
||||
" # Load only decoder parameters"
|
||||
" --init_param some/where/model.pth:decoder:decoder\n"
|
||||
" --init_param some/where/model.pb:decoder:decoder\n"
|
||||
" # Load only decoder parameters excluding decoder.embed"
|
||||
" --init_param some/where/model.pth:decoder:decoder:decoder.embed\n"
|
||||
" --init_param some/where/model.pth:decoder:decoder:decoder.embed\n",
|
||||
" --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
|
||||
" --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ignore_init_mismatch",
|
||||
|
||||
@ -826,7 +826,7 @@ class ASRTaskUniASR(ASRTask):
|
||||
if "model.ckpt-" in model_name or ".bin" in model_name:
|
||||
model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
|
||||
'.pb')) if ".bin" in model_name else os.path.join(
|
||||
model_dir, "{}.pth".format(model_name))
|
||||
model_dir, "{}.pb".format(model_name))
|
||||
if os.path.exists(model_name_pth):
|
||||
logging.info("model_file is load from pth: {}".format(model_name_pth))
|
||||
model_dict = torch.load(model_name_pth, map_location=device)
|
||||
@ -1073,7 +1073,7 @@ class ASRTaskParaformer(ASRTask):
|
||||
if "model.ckpt-" in model_name or ".bin" in model_name:
|
||||
model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
|
||||
'.pb')) if ".bin" in model_name else os.path.join(
|
||||
model_dir, "{}.pth".format(model_name))
|
||||
model_dir, "{}.pb".format(model_name))
|
||||
if os.path.exists(model_name_pth):
|
||||
logging.info("model_file is load from pth: {}".format(model_name_pth))
|
||||
model_dict = torch.load(model_name_pth, map_location=device)
|
||||
|
||||
@ -553,7 +553,7 @@ class DiarTask(AbsTask):
|
||||
if ".bin" in model_name:
|
||||
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
|
||||
else:
|
||||
model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
|
||||
model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
|
||||
if os.path.exists(model_name_pth):
|
||||
logging.info("model_file is load from pth: {}".format(model_name_pth))
|
||||
model_dict = torch.load(model_name_pth, map_location=device)
|
||||
|
||||
@ -501,7 +501,7 @@ class SVTask(AbsTask):
|
||||
if ".bin" in model_name:
|
||||
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
|
||||
else:
|
||||
model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
|
||||
model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
|
||||
if os.path.exists(model_name_pth):
|
||||
logging.info("model_file is load from pth: {}".format(model_name_pth))
|
||||
model_dict = torch.load(model_name_pth, map_location=device)
|
||||
|
||||
@ -52,13 +52,13 @@ def load_pretrained_model(
|
||||
init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
|
||||
|
||||
Examples:
|
||||
>>> load_pretrained_model("somewhere/model.pth", model)
|
||||
>>> load_pretrained_model("somewhere/model.pth:decoder:decoder", model)
|
||||
>>> load_pretrained_model("somewhere/model.pth:decoder:decoder:", model)
|
||||
>>> load_pretrained_model("somewhere/model.pb", model)
|
||||
>>> load_pretrained_model("somewhere/model.pb:decoder:decoder", model)
|
||||
>>> load_pretrained_model("somewhere/model.pb:decoder:decoder:", model)
|
||||
>>> load_pretrained_model(
|
||||
... "somewhere/model.pth:decoder:decoder:decoder.embed", model
|
||||
... "somewhere/model.pb:decoder:decoder:decoder.embed", model
|
||||
... )
|
||||
>>> load_pretrained_model("somewhere/decoder.pth::decoder", model)
|
||||
>>> load_pretrained_model("somewhere/decoder.pb::decoder", model)
|
||||
"""
|
||||
sps = init_param.split(":", 4)
|
||||
if len(sps) == 4:
|
||||
|
||||
@ -205,9 +205,9 @@ class Trainer:
|
||||
else:
|
||||
scaler = None
|
||||
|
||||
if trainer_options.resume and (output_dir / "checkpoint.pth").exists():
|
||||
if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
|
||||
cls.resume(
|
||||
checkpoint=output_dir / "checkpoint.pth",
|
||||
checkpoint=output_dir / "checkpoint.pb",
|
||||
model=model,
|
||||
optimizers=optimizers,
|
||||
schedulers=schedulers,
|
||||
@ -361,7 +361,7 @@ class Trainer:
|
||||
},
|
||||
buffer,
|
||||
)
|
||||
trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pth"), buffer.getvalue())
|
||||
trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"), buffer.getvalue())
|
||||
else:
|
||||
torch.save(
|
||||
{
|
||||
@ -374,7 +374,7 @@ class Trainer:
|
||||
],
|
||||
"scaler": scaler.state_dict() if scaler is not None else None,
|
||||
},
|
||||
output_dir / "checkpoint.pth",
|
||||
output_dir / "checkpoint.pb",
|
||||
)
|
||||
|
||||
# 5. Save and log the model and update the link to the best model
|
||||
@ -382,22 +382,22 @@ class Trainer:
|
||||
buffer = BytesIO()
|
||||
torch.save(model.state_dict(), buffer)
|
||||
trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
|
||||
f"{iepoch}epoch.pth"),buffer.getvalue())
|
||||
f"{iepoch}epoch.pb"),buffer.getvalue())
|
||||
else:
|
||||
torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth")
|
||||
torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
|
||||
|
||||
# Creates a sym link latest.pth -> {iepoch}epoch.pth
|
||||
# Creates a sym link latest.pb -> {iepoch}epoch.pb
|
||||
if trainer_options.use_pai:
|
||||
p = os.path.join(trainer_options.output_dir, "latest.pth")
|
||||
p = os.path.join(trainer_options.output_dir, "latest.pb")
|
||||
if trainer_options.oss_bucket.object_exists(p):
|
||||
trainer_options.oss_bucket.delete_object(p)
|
||||
trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
|
||||
os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pth"), p)
|
||||
os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"), p)
|
||||
else:
|
||||
p = output_dir / "latest.pth"
|
||||
p = output_dir / "latest.pb"
|
||||
if p.is_symlink() or p.exists():
|
||||
p.unlink()
|
||||
p.symlink_to(f"{iepoch}epoch.pth")
|
||||
p.symlink_to(f"{iepoch}epoch.pb")
|
||||
|
||||
_improved = []
|
||||
for _phase, k, _mode in trainer_options.best_model_criterion:
|
||||
@ -407,16 +407,16 @@ class Trainer:
|
||||
# Creates sym links if it's the best result
|
||||
if best_epoch == iepoch:
|
||||
if trainer_options.use_pai:
|
||||
p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pth")
|
||||
p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
|
||||
if trainer_options.oss_bucket.object_exists(p):
|
||||
trainer_options.oss_bucket.delete_object(p)
|
||||
trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
|
||||
os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pth"),p)
|
||||
os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),p)
|
||||
else:
|
||||
p = output_dir / f"{_phase}.{k}.best.pth"
|
||||
p = output_dir / f"{_phase}.{k}.best.pb"
|
||||
if p.is_symlink() or p.exists():
|
||||
p.unlink()
|
||||
p.symlink_to(f"{iepoch}epoch.pth")
|
||||
p.symlink_to(f"{iepoch}epoch.pb")
|
||||
_improved.append(f"{_phase}.{k}")
|
||||
if len(_improved) == 0:
|
||||
logging.info("There are no improvements in this epoch")
|
||||
@ -438,7 +438,7 @@ class Trainer:
|
||||
type="model",
|
||||
metadata={"improved": _improved},
|
||||
)
|
||||
artifact.add_file(str(output_dir / f"{iepoch}epoch.pth"))
|
||||
artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
|
||||
aliases = [
|
||||
f"epoch-{iepoch}",
|
||||
"best" if best_epoch == iepoch else "",
|
||||
@ -473,12 +473,12 @@ class Trainer:
|
||||
|
||||
for e in range(1, iepoch):
|
||||
if trainer_options.use_pai:
|
||||
p = os.path.join(trainer_options.output_dir, f"{e}epoch.pth")
|
||||
p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
|
||||
if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
|
||||
trainer_options.oss_bucket.delete_object(p)
|
||||
_removed.append(str(p))
|
||||
else:
|
||||
p = output_dir / f"{e}epoch.pth"
|
||||
p = output_dir / f"{e}epoch.pb"
|
||||
if p.exists() and e not in nbests:
|
||||
p.unlink()
|
||||
_removed.append(str(p))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user