8 changes: 8 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -150,6 +150,14 @@ class DatasetArguments(CustomDatasetArguments):
         default=False,
         metadata={"help": "Overwrite the cached preprocessed datasets or not."},
     )
+    cache_dir: Optional[str] = field(
+        init=False,
+        default=None,
+        metadata={
+            "help": "Where to store the pretrained datasets from huggingface.co. "
+            "This field is set from model_args.cache_dir to enable unified caching."
+        },
+    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
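Because the new field is declared with init=False, cache_dir is not accepted by the DatasetArguments constructor; it has to be assigned after parsing, which is what the parse_args change below does. A minimal sketch of that dataclass pattern, using illustrative names rather than the real classes:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DatasetArgs:
    streaming: bool = False
    # init=False keeps the field out of the generated __init__
    cache_dir: Optional[str] = field(init=False, default=None)

args = DatasetArgs()               # cache_dir cannot be passed to the constructor
args.cache_dir = "/mnt/shared/hf"  # assigned afterwards, as parse_args does below
print(args.cache_dir)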
5 changes: 5 additions & 0 deletions src/llmcompressor/args/utils.py
@@ -79,4 +79,9 @@ def parse_args(
     # silently assign tokenizer to processor
     resolve_processor_from_model_args(model_args)

+    # copy cache_dir from model_args to dataset_args to support offline mode
+    # with a single unified cache directory. This allows both models and datasets
+    # to use the same cache when cache_dir is specified
+    dataset_args.cache_dir = model_args.cache_dir
+
     return model_args, dataset_args, recipe_args, training_args, output_dir
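A hedged sketch of the offline scenario this propagation targets. HF_HUB_OFFLINE is a real Hugging Face environment switch; the model name and cache path are only examples, and the calls go through transformers directly rather than llmcompressor's entrypoints:

import os

# Flip to offline before importing transformers/huggingface_hub, so every file
# must already be present in the local cache (pre-populated while online).
os.environ["HF_HUB_OFFLINE"] = "1"

from transformers import AutoModelForCausalLM, AutoTokenizer

shared_cache = "/mnt/shared/hf"  # the single directory passed as cache_dir

# The model resolves from the shared cache; after this change the calibration
# dataset uses the same directory, since dataset_args.cache_dir now mirrors
# model_args.cache_dir.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", cache_dir=shared_cache)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", cache_dir=shared_cache)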
5 changes: 3 additions & 2 deletions src/llmcompressor/pytorch/model_load/helpers.py
@@ -149,16 +149,17 @@ def copy_python_files_from_model_cache(model, save_path: str):
     import shutil

     from huggingface_hub import hf_hub_download
-    from transformers import TRANSFORMERS_CACHE
     from transformers.utils import http_user_agent

     cache_path = config._name_or_path
     if not os.path.exists(cache_path):
         user_agent = http_user_agent()
+        # Use cache_dir=None to respect HF_HOME, HF_HUB_CACHE, and other
+        # environment variables for cache location
         config_file_path = hf_hub_download(
             repo_id=cache_path,
             filename="config.json",
-            cache_dir=TRANSFORMERS_CACHE,
+            cache_dir=None,
             force_download=False,
             user_agent=user_agent,
         )

Review comment (Collaborator), on the cache_dir=None line: why isn't this model_args.cache_dir instead of None?
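For context on the cache_dir=None choice questioned above: hf_hub_download treats None as "use the default Hub cache", which huggingface_hub derives from HF_HUB_CACHE / HF_HOME, instead of pinning the legacy TRANSFORMERS_CACHE constant. A small hedged sketch; the repo id and path are illustrative:

import os

# The default cache location is read from HF_HOME / HF_HUB_CACHE when
# huggingface_hub is imported, so set the environment first.
os.environ["HF_HOME"] = "/mnt/shared/hf"

from huggingface_hub import hf_hub_download

config_path = hf_hub_download(
    repo_id="facebook/opt-125m",  # illustrative repo, not taken from the diff
    filename="config.json",
    cache_dir=None,               # fall back to the environment-driven cache
)
print(config_path)  # lands under /mnt/shared/hf/hub/models--facebook--opt-125m/...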
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/finetune/data/base.py
@@ -195,7 +195,7 @@ def load_dataset(self):
         logger.debug(f"Loading dataset {self.dataset_args.dataset}")
         return get_raw_dataset(
             self.dataset_args,
-            None,
+            cache_dir=self.dataset_args.cache_dir,
             split=self.split,
             streaming=self.dataset_args.streaming,
             **self.dataset_args.raw_kwargs,
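For reference, the value forwarded here is the cache_dir that the Hugging Face datasets loader accepts directly, so calibration data can land in the same unified cache. A standalone hedged illustration; the dataset name and path are examples, not taken from the PR:

from datasets import load_dataset

# With an explicit cache_dir, downloads and prepared Arrow files are stored
# under this directory instead of the default ~/.cache/huggingface/datasets.
ds = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    split="train",
    cache_dir="/mnt/shared/hf/datasets",  # illustrative shared location
)
print(ds)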