mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
hf hub
This commit is contained in:
parent
fffbefc28b
commit
fb45c9a6ef
@ -10,7 +10,7 @@ def download_model(**kwargs):
|
||||
if hub == "ms":
|
||||
kwargs = download_from_ms(**kwargs)
|
||||
elif hub == "hf":
|
||||
pass
|
||||
kwargs = download_from_hf(**kwargs)
|
||||
elif hub == "openai":
|
||||
model_or_path = kwargs.get("model")
|
||||
if os.path.exists(model_or_path):
|
||||
@ -87,6 +87,67 @@ def download_from_ms(**kwargs):
|
||||
return kwargs
|
||||
|
||||
|
||||
def download_from_hf(**kwargs):
    """Resolve a Hugging Face model and merge its on-disk config into ``kwargs``.

    Keys read from ``kwargs``:
        model: model name, hub repo id, or local directory. Names found in
            ``name_maps_hf`` are translated first.
        model_revision: forwarded to ``get_or_download_model_dir_hf``
            (defaults to ``"master"``).
        model_path: when already present, no download is attempted and the
            value is left untouched.
        is_training, check_latest: forwarded to
            ``get_or_download_model_dir_hf``.

    Returns:
        The updated keyword arguments. If the merge produced an OmegaConf
        ``DictConfig`` it is converted back to a plain container before
        being returned.

    Side effects:
        If the model directory contains ``requirements.txt``, its path is
        handed to ``install_requirements``.
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]
    # NOTE(review): "master" mirrors the ModelScope code path; Hugging Face
    # repos usually use "main" as default branch -- confirm before the
    # revision is actually forwarded to the hub client.
    model_revision = kwargs.get("model_revision", "master")
    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
        try:
            model_or_path = get_or_download_model_dir_hf(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
            )
        except Exception as e:
            # Best effort: report the failure and continue with the
            # unresolved identifier; the existence checks below will then
            # simply not match anything.
            print(f"Download: {model_or_path} failed!: {e}")

    # Keep a caller-supplied model_path; otherwise record the resolved one.
    kwargs.setdefault("model_path", model_or_path)

    configuration_json = os.path.join(model_or_path, "configuration.json")
    config_yaml = os.path.join(model_or_path, "config.yaml")
    checkpoint = os.path.join(model_or_path, "model.pt")

    if os.path.exists(configuration_json):
        # Only the parse needs the file handle; release it right away.
        with open(configuration_json, "r", encoding="utf-8") as f:
            conf_json = json.load(f)

        cfg = {}
        if "file_path_metas" in conf_json:
            # presumably fills cfg with paths rooted at the model directory
            # -- confirm against add_file_root_path.
            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        # Explicit caller arguments override values derived from the file.
        cfg.update(kwargs)
        if "config" in cfg:
            config = OmegaConf.load(cfg["config"])
            kwargs = OmegaConf.merge(config, cfg)
            kwargs["model"] = config["model"]
    elif os.path.exists(config_yaml) and os.path.exists(checkpoint):
        config = OmegaConf.load(config_yaml)
        kwargs = OmegaConf.merge(config, kwargs)
        # Use the very checkpoint whose existence was just verified
        # (previously this pointed at "model.pb", which was never checked).
        kwargs["init_param"] = checkpoint

        # (file name, tokenizer_conf key); order matters because
        # tokens.json overrides tokens.txt when both are present.
        tokenizer_files = (
            ("tokens.txt", "token_list"),
            ("tokens.json", "token_list"),
            ("seg_dict", "seg_dict"),
            ("bpe.model", "bpemodel"),
        )
        for filename, conf_key in tokenizer_files:
            candidate = os.path.join(model_or_path, filename)
            if os.path.exists(candidate):
                kwargs["tokenizer_conf"][conf_key] = candidate

        kwargs["model"] = config["model"]

        cmvn_file = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(cmvn_file):
            kwargs["frontend_conf"]["cmvn_file"] = cmvn_file
        jieba_usr_dict = os.path.join(model_or_path, "jieba_usr_dict")
        if os.path.exists(jieba_usr_dict):
            kwargs["jieba_usr_dict"] = jieba_usr_dict

    if isinstance(kwargs, DictConfig):
        kwargs = OmegaConf.to_container(kwargs, resolve=True)

    requirements = os.path.join(model_or_path, "requirements.txt")
    if os.path.exists(requirements):
        print(f"Detect model requirements, begin to install it: {requirements}")
        from funasr.utils.install_model_requirements import install_requirements

        install_requirements(requirements)
    return kwargs
|
||||
|
||||
|
||||
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
|
||||
|
||||
if isinstance(file_path_metas, dict):
|
||||
@ -136,3 +197,22 @@ def get_or_download_model_dir(
|
||||
model, revision=model_revision, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"}
|
||||
)
|
||||
return model_cache_dir
|
||||
|
||||
|
||||
def get_or_download_model_dir_hf(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
):
    """Return the directory of a Hugging Face snapshot for ``model``.

    The lookup is delegated to ``huggingface_hub.snapshot_download``, which
    is called with ``model`` as its only argument.

    Args:
        model (str): value handed to ``snapshot_download``.
        model_revision (str, optional): accepted but not used by this
            function.
        is_training (bool): accepted but not used by this function.
        check_latest (bool): accepted but not used by this function.

    Returns:
        Whatever ``snapshot_download(model)`` returns.
    """
    # Imported lazily so huggingface_hub is only needed when this hub is used.
    from huggingface_hub import snapshot_download

    return snapshot_download(model)
|
||||
|
||||
@ -14,7 +14,9 @@ name_maps_ms = {
|
||||
"Qwen-Audio": "Qwen/Qwen-Audio",
|
||||
}
|
||||
|
||||
name_maps_hf = {}
|
||||
name_maps_hf = {
|
||||
"": "",
|
||||
}
|
||||
|
||||
name_maps_openai = {
|
||||
"Whisper-tiny.en": "tiny.en",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user