This commit is contained in:
游雁 2024-06-12 19:53:59 +08:00
parent 83d644c899
commit f6cae2b48b

View File

@ -61,7 +61,7 @@ def download_from_ms(**kwargs):
): ):
config = OmegaConf.load(os.path.join(model_or_path, "config.yaml")) config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
kwargs = OmegaConf.merge(config, kwargs) kwargs = OmegaConf.merge(config, kwargs)
init_param = os.path.join(model_or_path, "model.pb") init_param = os.path.join(model_or_path, "model.pt")
kwargs["init_param"] = init_param kwargs["init_param"] = init_param
if os.path.exists(os.path.join(model_or_path, "tokens.txt")): if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt") kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
@ -122,7 +122,7 @@ def download_from_hf(**kwargs):
): ):
config = OmegaConf.load(os.path.join(model_or_path, "config.yaml")) config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
kwargs = OmegaConf.merge(config, kwargs) kwargs = OmegaConf.merge(config, kwargs)
init_param = os.path.join(model_or_path, "model.pb") init_param = os.path.join(model_or_path, "model.pt")
kwargs["init_param"] = init_param kwargs["init_param"] = init_param
if os.path.exists(os.path.join(model_or_path, "tokens.txt")): if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt") kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")