FunASR/egs/aishell/transformer/utils/prepare_checkpoint.py
2023-08-07 11:12:26 +08:00

20 lines
664 B
Python

import os
import shutil
from modelscope.hub.snapshot_download import snapshot_download
if __name__ == '__main__':
import sys
model = sys.argv[1]
checkpoint_dir = sys.argv[2]
checkpoint_name = sys.argv[3]
try:
pretrained_model_path = snapshot_download(model, cache_dir=checkpoint_dir)
except BaseException:
raise BaseException(f"Please download pretrain model from ModelScope firstly.")
shutil.copy(os.path.join(checkpoint_dir, checkpoint_name), os.path.join(pretrained_model_path, "model.pb"))
shutil.copy(os.path.join(checkpoint_dir, "config.yaml"), os.path.join(pretrained_model_path, "config.yaml"))