Commit b9efdc6

Refactor Dataset.map to reuse cache files mapped with different num_proc (#7434)

* Refactor Dataset.map to reuse cache files mapped with a different num_proc. Fixes #7433.

  This refactor unifies the num_proc is None or num_proc == 1 case and the num_proc > 1 case. Instead of handling them completely separately (one using a list of kwargs and shards, the other a single set of kwargs and self), the num_proc == 1 case is wrapped in a list, so the only difference is whether a pool is used; either case can then load the other's cache files simply by changing num_shards. In particular, num_proc == 1 can sequentially load the shards of a dataset mapped with num_shards > 1 and sequentially map any missing shards.

  Beyond the structural refactor, the main contribution of this PR is get_existing_cache_file_map, which uses a regex built from cache_file_name and suffix_template to find existing cache files, grouped by their num_shards. Using this data structure, we can reset num_shards to match an existing set of cache files and load them accordingly.

* Only give the reprocessing message when doing a partial remap; also fix spacing in the message.

* Update the logging message to account for whether a cache file will be written at all, and whether it will be written by the main process or not.

* Refactor string_to_dict to return None if there is no match instead of raising ValueError. Instead of using try-except to handle the no-match case, callers can check whether the return value is None, and assert that it is not None when a match is known to exist.

* Simplify existing_cache_file_map with string_to_dict (#7434 (comment))

* Set the progress bar's initial value when existing cache files are found (#7434 (comment))

* Allow source_url_fields to be None; they can be local file paths here: https://github.com/huggingface/datasets/actions/runs/13683185040/job/38380924390?pr=7435#step:10:9731

* Add a unicode escape to handle parsing string_to_dict with Windows paths

* Remove glob_pattern_to_regex: all the tests still pass when it is removed; the unicode escaping appears to do some of the work that glob_pattern_to_regex was doing here before

* Fix dependencies

---------

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>

1 parent 491d808 commit b9efdc6
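To illustrate the behavior this commit targets, here is a hedged, hypothetical sketch (the dataset, the double function, and the /tmp/mapped.arrow cache path are made up, and it assumes caching to an explicit cache_file_name is enabled): a map run with num_proc=4 writes four shard cache files, and re-running the same map with num_proc=1 can now reload those shards instead of recomputing them.

```python
# Hypothetical illustration of the intended behavior; not part of the commit.
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(1_000))})

def double(example):
    return {"x": example["x"] * 2}

# First run: with num_proc=4 the map writes shard caches such as
# /tmp/mapped_00000_of_00004.arrow ... /tmp/mapped_00003_of_00004.arrow
ds_parallel = ds.map(double, num_proc=4, cache_file_name="/tmp/mapped.arrow")

# Second run: after this commit, the four existing shard caches are detected
# (grouped by num_shards) and loaded sequentially instead of being remapped.
ds_single = ds.map(double, num_proc=1, cache_file_name="/tmp/mapped.arrow")
```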

File tree

5 files changed: +264 -114 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@
     "zstandard",
     "polars[timezone]>=0.20.0",
     "torchvision",
-    "pyav",
+    "av",
 ]
src/datasets/arrow_dataset.py

Lines changed: 175 additions & 111 deletions
@@ -19,6 +19,7 @@
 import contextlib
 import copy
 import fnmatch
+import glob
 import inspect
 import itertools
 import json
@@ -27,12 +28,13 @@
 import posixpath
 import re
 import shutil
+import string
 import sys
 import tempfile
 import time
 import warnings
 import weakref
-from collections import Counter
+from collections import Counter, defaultdict
 from collections.abc import Iterable, Iterator, Mapping
 from collections.abc import Sequence as Sequence_
 from copy import deepcopy
@@ -2963,6 +2965,11 @@ def map(
         if num_proc is not None and num_proc <= 0:
             raise ValueError("num_proc must be an integer > 0.")

+        string_formatter = string.Formatter()
+        fields = {field_name for _, field_name, _, _ in string_formatter.parse(suffix_template) if field_name}
+        if fields != {"rank", "num_proc"}:
+            raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}")
+
         # If the array is empty we do nothing (but we make sure to handle an empty indices mapping and remove the requested columns anyway)
         if len(self) == 0:
             if self._indices is not None:  # empty indices mapping
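The validation added above only inspects the template's field names. A standalone sketch, assuming the default suffix_template value "_{rank:05d}_of_{num_proc:05d}" used by Dataset.map:

```python
import string

suffix_template = "_{rank:05d}_of_{num_proc:05d}"  # assumed default template
fields = {
    field_name
    # string.Formatter().parse yields (literal_text, field_name, format_spec, conversion)
    for _, field_name, _, _ in string.Formatter().parse(suffix_template)
    if field_name
}
assert fields == {"rank", "num_proc"}
```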
@@ -3045,7 +3052,14 @@ def map(
                 cache_file_name = self._get_cache_file_path(new_fingerprint)
             dataset_kwargs["cache_file_name"] = cache_file_name

-        def load_processed_shard_from_cache(shard_kwargs):
+        if cache_file_name is not None:
+            cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
+            if not cache_file_ext:
+                raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
+        else:
+            cache_file_prefix = cache_file_ext = None
+
+        def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
             """Load a processed shard from cache if it exists, otherwise throw an error."""
             shard = shard_kwargs["shard"]
             # Check if we've already cached this computation (indexed by a hash)
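For reference, a small sketch of how the prefix/extension split above feeds into per-shard cache file names (the cache path is hypothetical):

```python
import os

cache_file_name = "/data/cache/cache-1a2b3c.arrow"  # hypothetical cache file
cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
assert cache_file_prefix == "/data/cache/cache-1a2b3c"
assert cache_file_ext == ".arrow"

# Shard caches are then named prefix + suffix_template + ext, for example:
suffix_template = "_{rank:05d}_of_{num_proc:05d}"
shard_name = cache_file_prefix + suffix_template.format(rank=0, num_proc=4) + cache_file_ext
assert shard_name == "/data/cache/cache-1a2b3c_00000_of_00004.arrow"
```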
@@ -3056,64 +3070,71 @@ def load_processed_shard_from_cache(shard_kwargs):
                     return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split)
             raise NonExistentDatasetError

-        num_shards = num_proc if num_proc is not None else 1
-        if batched and drop_last_batch:
-            pbar_total = len(self) // num_shards // batch_size * num_shards * batch_size
-        else:
-            pbar_total = len(self)
+        existing_cache_file_map: dict[int, list[str]] = defaultdict(list)
+        if cache_file_name is not None:
+            if os.path.exists(cache_file_name):
+                existing_cache_file_map[1] = [cache_file_name]

-        shards_done = 0
-        if num_proc is None or num_proc == 1:
-            transformed_dataset = None
-            try:
-                transformed_dataset = load_processed_shard_from_cache(dataset_kwargs)
-                logger.info(f"Loading cached processed dataset at {dataset_kwargs['cache_file_name']}")
-            except NonExistentDatasetError:
-                pass
-            if transformed_dataset is None:
-                with hf_tqdm(
-                    unit=" examples",
-                    total=pbar_total,
-                    desc=desc or "Map",
-                ) as pbar:
-                    for rank, done, content in Dataset._map_single(**dataset_kwargs):
-                        if done:
-                            shards_done += 1
-                            logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
-                            transformed_dataset = content
-                        else:
-                            pbar.update(content)
-            assert transformed_dataset is not None, "Failed to retrieve the result from map"
-            # update fingerprint if the dataset changed
-            if transformed_dataset._fingerprint != self._fingerprint:
-                transformed_dataset._fingerprint = new_fingerprint
-            return transformed_dataset
-        else:
+            assert cache_file_prefix is not None and cache_file_ext is not None
+            cache_file_with_suffix_pattern = cache_file_prefix + suffix_template + cache_file_ext

-            def format_cache_file_name(
-                cache_file_name: Optional[str],
-                rank: Union[int, Literal["*"]],  # noqa: F722
-            ) -> Optional[str]:
-                if not cache_file_name:
-                    return cache_file_name
-                sep = cache_file_name.rindex(".")
-                base_name, extension = cache_file_name[:sep], cache_file_name[sep:]
-                if isinstance(rank, int):
-                    cache_file_name = base_name + suffix_template.format(rank=rank, num_proc=num_proc) + extension
-                    logger.info(f"Process #{rank} will write at {cache_file_name}")
-                else:
-                    cache_file_name = (
-                        base_name
-                        + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_proc)
-                        + extension
-                    )
+            for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"):
+                suffix_variable_map = string_to_dict(cache_file, cache_file_with_suffix_pattern)
+                if suffix_variable_map is not None:
+                    file_num_proc = int(suffix_variable_map["num_proc"])
+                    existing_cache_file_map[file_num_proc].append(cache_file)
+
+        num_shards = num_proc or 1
+        if existing_cache_file_map:
+            # to avoid remapping when a different num_proc is given than when originally cached, update num_shards to
+            # what was used originally
+
+            def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
+                percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
+                num_shards_diff = abs(mapped_num_proc - num_shards)
+                return (
+                    percent_missing,  # choose the most complete set of existing cache files
+                    num_shards_diff,  # then choose the mapped_num_proc closest to the current num_proc
+                    mapped_num_proc,  # finally, choose whichever mapped_num_proc is lower
+                )
+
+            num_shards = min(existing_cache_file_map, key=select_existing_cache_files)
+
+        existing_cache_files = existing_cache_file_map[num_shards]
+
+        def format_cache_file_name(
+            cache_file_name: Optional[str],
+            rank: Union[int, Literal["*"]],  # noqa: F722
+        ) -> Optional[str]:
+            if not cache_file_name:
                 return cache_file_name

-            def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
-                new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_proc)
-                validate_fingerprint(new_fingerprint)
-                return new_fingerprint
+            assert cache_file_prefix is not None and cache_file_ext is not None
+
+            if isinstance(rank, int):
+                cache_file_name = (
+                    cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext
+                )
+                if not os.path.exists(cache_file_name):
+                    process_name = (
+                        "Main process" if num_proc is None or num_proc == 1 else f"Process #{rank % num_shards + 1}"
+                    )
+                    logger.info(f"{process_name} will write at {cache_file_name}")
+            else:
+                # TODO: this assumes the format_spec of rank in suffix_template
+                cache_file_name = (
+                    cache_file_prefix
+                    + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_shards)
+                    + cache_file_ext
+                )
+            return cache_file_name
+
+        def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
+            new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_shards)
+            validate_fingerprint(new_fingerprint)
+            return new_fingerprint

+        if num_proc is not None and num_proc > 1:
             prev_env = deepcopy(os.environ)
             # check if parallelism if off
             # from https://github.com/huggingface/tokenizers/blob/bb668bc439dc34389b71dbb8ce0c597f15707b53/tokenizers/src/utils/parallelism.rs#L22
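To make the shard-selection heuristic above concrete, here is a minimal sketch with a hypothetical existing_cache_file_map: a complete set of 4 cached shards wins over a partial set of 8, even though num_proc=2 was requested.

```python
from collections import defaultdict

existing_cache_file_map: dict[int, list[str]] = defaultdict(list)
existing_cache_file_map[4] = [f"cache_0000{r}_of_00004.arrow" for r in range(4)]  # complete set
existing_cache_file_map[8] = [f"cache_0000{r}_of_00008.arrow" for r in range(5)]  # 3 shards missing

num_proc = 2
num_shards = num_proc or 1

def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
    # Same ranking as the diff: completeness first, then closeness to num_proc, then the smaller count.
    percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
    num_shards_diff = abs(mapped_num_proc - num_shards)
    return (percent_missing, num_shards_diff, mapped_num_proc)

num_shards = min(existing_cache_file_map, key=select_existing_cache_files)
assert num_shards == 4  # the fully cached shard count is selected
```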
@@ -3128,9 +3149,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
             ):
                 logger.warning("Setting TOKENIZERS_PARALLELISM=false for forked processes.")
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        else:
+            prev_env = os.environ
+
+        kwargs_per_job: list[Optional[dict[str, Any]]]
+        if num_shards == 1:
+            shards = [self]
+            kwargs_per_job = [dataset_kwargs]
+        else:
             shards = [
-                self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
-                for rank in range(num_proc)
+                self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
+                for rank in range(num_shards)
             ]
             kwargs_per_job = [
                 {
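A small sketch of the contiguous sharding used above (not part of the diff): shards taken with contiguous=True split the rows into consecutive blocks, so concatenating them reproduces the original dataset.

```python
from datasets import Dataset, concatenate_datasets

ds = Dataset.from_dict({"x": list(range(10))})
num_shards = 3
shards = [
    ds.shard(num_shards=num_shards, index=rank, contiguous=True)
    for rank in range(num_shards)
]
# Contiguous shards cover consecutive row ranges (here 4, 3, and 3 rows).
assert [len(s) for s in shards] == [4, 3, 3]
assert concatenate_datasets(shards).to_dict() == ds.to_dict()
```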
@@ -3144,62 +3173,97 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
                 for rank in range(num_shards)
             ]

-            transformed_shards = [None] * num_shards
-            for rank in range(num_shards):
-                try:
-                    transformed_shards[rank] = load_processed_shard_from_cache(kwargs_per_job[rank])
-                    kwargs_per_job[rank] = None
-                except NonExistentDatasetError:
-                    pass
-
-            kwargs_per_job = [kwargs for kwargs in kwargs_per_job if kwargs is not None]
-
-            # We try to create a pool with as many workers as dataset not yet cached.
-            if kwargs_per_job:
-                if len(kwargs_per_job) < num_shards:
-                    logger.info(
-                        f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
-                    )
-                with Pool(len(kwargs_per_job)) as pool:
-                    os.environ = prev_env
-                    logger.info(f"Spawning {num_proc} processes")
-                    with hf_tqdm(
-                        unit=" examples",
-                        total=pbar_total,
-                        desc=(desc or "Map") + f" (num_proc={num_proc})",
-                    ) as pbar:
+        transformed_shards: list[Optional[Dataset]] = [None] * num_shards
+        for rank in range(num_shards):
+            try:
+                job_kwargs = kwargs_per_job[rank]
+                assert job_kwargs is not None
+                transformed_shards[rank] = load_processed_shard_from_cache(job_kwargs)
+                kwargs_per_job[rank] = None
+            except NonExistentDatasetError:
+                pass
+
+        if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]:
+            if len(unprocessed_kwargs_per_job) != num_shards:
+                logger.info(
+                    f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were "
+                    "missing from the cache."
+                )
+
+            pbar_total = len(self)
+            pbar_initial = len(existing_cache_files) * pbar_total // num_shards
+            if batched and drop_last_batch:
+                batch_size = batch_size or 1
+                pbar_initial = pbar_initial // num_shards // batch_size * num_shards * batch_size
+                pbar_total = pbar_total // num_shards // batch_size * num_shards * batch_size
+
+            with hf_tqdm(
+                unit=" examples",
+                initial=pbar_initial,
+                total=pbar_total,
+                desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""),
+            ) as pbar:
+                shards_done = 0
+
+                def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, int]) -> None:
+                    nonlocal shards_done
+                    if done:
+                        shards_done += 1
+                        logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
+                        assert isinstance(content, Dataset)
+                        transformed_shards[rank or 0] = content
+                    else:
+                        assert isinstance(content, int)
+                        pbar.update(content)
+
+                if num_proc is not None and num_proc > 1:
+                    with Pool(num_proc) as pool:
+                        os.environ = prev_env
+                        logger.info(f"Spawning {num_proc} processes")
+
                         for rank, done, content in iflatmap_unordered(
-                            pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
+                            pool, Dataset._map_single, kwargs_iterable=unprocessed_kwargs_per_job
                         ):
-                            if done:
-                                shards_done += 1
-                                logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
-                                transformed_shards[rank] = content
-                            else:
-                                pbar.update(content)
-                    pool.close()
-                    pool.join()
-                # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
-                for kwargs in kwargs_per_job:
-                    del kwargs["shard"]
-            else:
-                logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
-            if None in transformed_shards:
-                raise ValueError(
-                    f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at "
-                    "least one worker failed to return its results"
-                )
-            logger.info(f"Concatenating {num_proc} shards")
-            result = _concatenate_map_style_datasets(transformed_shards)
-            # update fingerprint if the dataset changed
+                            check_if_shard_done(rank, done, content)
+
+                        pool.close()
+                        pool.join()
+                else:
+                    for unprocessed_kwargs in unprocessed_kwargs_per_job:
+                        for rank, done, content in Dataset._map_single(**unprocessed_kwargs):
+                            check_if_shard_done(rank, done, content)
+
+            # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
+            for job_kwargs in unprocessed_kwargs_per_job:
+                if "shard" in job_kwargs:
+                    del job_kwargs["shard"]
+        else:
+            logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
+
+        all_transformed_shards = [shard for shard in transformed_shards if shard is not None]
+        if len(transformed_shards) != len(all_transformed_shards):
+            raise ValueError(
+                f"Failed to retrieve results from map: result list {transformed_shards} still contains None - "
+                "at least one worker failed to return its results"
+            )
+
+        if num_shards == 1:
+            result = all_transformed_shards[0]
+        else:
+            logger.info(f"Concatenating {num_shards} shards")
+            result = _concatenate_map_style_datasets(all_transformed_shards)
+
+        # update fingerprint if the dataset changed
+        result._fingerprint = (
+            new_fingerprint
             if any(
                 transformed_shard._fingerprint != shard._fingerprint
-                for transformed_shard, shard in zip(transformed_shards, shards)
-            ):
-                result._fingerprint = new_fingerprint
-            else:
-                result._fingerprint = self._fingerprint
-            return result
+                for transformed_shard, shard in zip(all_transformed_shards, shards)
+            )
+            else self._fingerprint
+        )
+
+        return result

     @staticmethod
     def _map_single(
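A worked example of the progress-bar accounting introduced above, with hypothetical numbers (1,000 examples, 4 shards, 3 of them already cached, no batching):

```python
num_examples = 1_000
num_shards = 4
existing_cache_files = ["shard0.arrow", "shard1.arrow", "shard2.arrow"]  # 3 cached shards

pbar_total = num_examples
pbar_initial = len(existing_cache_files) * pbar_total // num_shards
# The bar starts at 750/1000, so only the one missing shard's examples remain to be mapped.
assert (pbar_initial, pbar_total) == (750, 1000)
```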
@@ -3222,7 +3286,7 @@ def _map_single(
         rank: Optional[int] = None,
         offset: int = 0,
         try_original_type: Optional[bool] = True,
-    ) -> Iterable[tuple[int, bool, Union[int, "Dataset"]]]:
+    ) -> Iterable[tuple[Optional[int], bool, Union[int, "Dataset"]]]:
         """Apply a function to all the elements in the table (individually or in batches)
         and update the table (if function does update examples).
@@ -5762,7 +5826,7 @@ def push_to_hub(
     @transmit_format
     @fingerprint_transform(inplace=False)
     def add_column(
-        self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None
+        self, name: str, column: Union[list, np.ndarray], new_fingerprint: str, feature: Optional[FeatureType] = None
     ):
         """Add column to Dataset.
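The last hunk only fixes a type hint: np.ndarray is the array type, while np.array is the factory function. A quick usage sketch of add_column with a NumPy array (the column values are made up):

```python
import numpy as np
from datasets import Dataset

ds = Dataset.from_dict({"a": [1, 2, 3]})
# new_fingerprint is injected by the fingerprint_transform decorator, so callers only pass name and column.
ds = ds.add_column("b", np.array([4.0, 5.0, 6.0]))
assert ds.column_names == ["a", "b"]
```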

src/datasets/data_files.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from .utils import logging
 from .utils import tqdm as hf_tqdm
 from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin
-from .utils.py_utils import glob_pattern_to_regex, string_to_dict
+from .utils.py_utils import string_to_dict


 SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]]
@@ -265,7 +265,7 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> di
         if len(data_files) > 0:
             splits: set[str] = set()
             for p in data_files:
-                p_parts = string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))
+                p_parts = string_to_dict(xbasename(p), xbasename(split_pattern))
                 assert p_parts is not None
                 splits.add(p_parts["split"])
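A hedged sketch of the simplified call above, using a hypothetical basename and split pattern rather than the resolver's real ones:

```python
from datasets.utils.py_utils import string_to_dict

# Hypothetical basename and pattern; the real ones come from the pattern resolver.
p_parts = string_to_dict("train-00000-of-00001.parquet", "{split}-00000-of-00001.parquet")
assert p_parts is not None
assert p_parts["split"] == "train"
```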

src/datasets/utils/py_utils.py

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ def string_to_dict(string: str, pattern: str) -> Optional[dict[str, str]]:
         Optional[dict[str, str]]: dictionary of variable -> value, retrieved from the input using the pattern, or
             `None` if the string does not match the pattern.
     """
+    pattern = pattern.encode("unicode_escape").decode("utf-8")  # C:\\Users -> C:\\\\Users for Windows paths
     pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern)  # remove format specifiers, e.g. {rank:05d} -> {rank}
     regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", pattern)
     result = re.search(regex, string)
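A short sketch of the new behavior, using a hypothetical Windows-style cache path: backslashes are escaped before the pattern is turned into a regex, and a non-matching string now returns None instead of raising ValueError.

```python
from datasets.utils.py_utils import string_to_dict

pattern = r"C:\cache\mapped_{rank:05d}_of_{num_proc:05d}.arrow"  # hypothetical path
parsed = string_to_dict(r"C:\cache\mapped_00002_of_00008.arrow", pattern)
assert parsed == {"rank": "00002", "num_proc": "00008"}

# No match: the function now returns None rather than raising ValueError.
assert string_to_dict("unrelated.txt", pattern) is None
```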
