Skip to content

Commit 59a8068

Browse files
authored
fix auto modeling resolve_cache_dir (PaddlePaddle#454)
1 parent 6272cbd commit 59a8068

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

paddlemix/auto/modeling.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@
1616
import json
1717
import os
1818
from collections import defaultdict
19+
from typing import Optional
1920

2021
from huggingface_hub import hf_hub_download
2122
from paddlenlp.transformers.configuration_utils import is_standard_config
2223
from paddlenlp.transformers.model_utils import PretrainedModel
23-
from paddlenlp.transformers.utils import resolve_cache_dir
2424
from paddlenlp.utils.downloader import (
2525
COMMUNITY_MODEL_PREFIX,
2626
get_path_from_url_with_filelock,
2727
hf_file_exists,
2828
url_file_exists,
2929
)
30+
from paddlenlp.utils.env import HF_CACHE_HOME as PPNLP_HF_CACHE_HOME
31+
from paddlenlp.utils.env import MODEL_HOME as PPNLP_MODEL_HOME
3032
from paddlenlp.utils.import_utils import import_module
3133
from paddlenlp.utils.log import logger
3234

@@ -57,6 +59,22 @@
5759
}
5860

5961

62+
def resolve_cache_dir(from_hf_hub: bool, from_aistudio: bool, cache_dir: Optional[str] = None) -> str:
63+
"""resolve cache dir for PretrainedModel and PretrainedConfig
64+
65+
Args:
66+
from_hf_hub (bool): if load from huggingface hub
67+
cache_dir (str): cache_dir for models
68+
"""
69+
if cache_dir is not None:
70+
return cache_dir
71+
if from_aistudio:
72+
return None
73+
if from_hf_hub:
74+
return PPNLP_HF_CACHE_HOME
75+
return PPNLP_MODEL_HOME
76+
77+
6078
def get_model_mapping():
6179

6280
# 1. search the subdir<model-name> to find model-names

0 commit comments

Comments
 (0)