@@ -19,6 +19,7 @@
 import contextlib
 import copy
 import fnmatch
+import glob
 import inspect
 import itertools
 import json
@@ -27,12 +28,13 @@
 import posixpath
 import re
 import shutil
+import string
 import sys
 import tempfile
 import time
 import warnings
 import weakref
-from collections import Counter
+from collections import Counter, defaultdict
 from collections.abc import Iterable, Iterator, Mapping
 from collections.abc import Sequence as Sequence_
 from copy import deepcopy
@@ -2963,6 +2965,11 @@ def map(
         if num_proc is not None and num_proc <= 0:
             raise ValueError("num_proc must be an integer > 0.")
 
+        string_formatter = string.Formatter()
+        fields = {field_name for _, field_name, _, _ in string_formatter.parse(suffix_template) if field_name}
+        if fields != {"rank", "num_proc"}:
+            raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}")
+
         # If the array is empty we do nothing (but we make sure to handle an empty indices mapping and remove the requested columns anyway)
         if len(self) == 0:
             if self._indices is not None:  # empty indices mapping
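The new guard uses `string.Formatter().parse()` to require that `suffix_template` names exactly the `rank` and `num_proc` replacement fields. A minimal stand-alone sketch of that check (the template values are illustrative; `"_{rank:05d}_of_{num_proc:05d}"` is the default `suffix_template` of `map`):

```python
import string

def template_fields(template: str) -> set:
    # Formatter.parse yields (literal_text, field_name, format_spec, conversion);
    # keeping the truthy field names recovers the template's replacement fields.
    return {name for _, name, _, _ in string.Formatter().parse(template) if name}

assert template_fields("_{rank:05d}_of_{num_proc:05d}") == {"rank", "num_proc"}  # accepted
assert template_fields("_{rank:05d}") != {"rank", "num_proc"}                    # would be rejected
```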
@@ -3045,7 +3052,14 @@ def map(
             cache_file_name = self._get_cache_file_path(new_fingerprint)
         dataset_kwargs["cache_file_name"] = cache_file_name
 
-        def load_processed_shard_from_cache(shard_kwargs):
+        if cache_file_name is not None:
+            cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
+            if not cache_file_ext:
+                raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
+        else:
+            cache_file_prefix = cache_file_ext = None
+
+        def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
             """Load a processed shard from cache if it exists, otherwise throw an error."""
             shard = shard_kwargs["shard"]
             # Check if we've already cached this computation (indexed by a hash)
@@ -3056,64 +3070,71 @@ def load_processed_shard_from_cache(shard_kwargs):
                     return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split)
             raise NonExistentDatasetError
 
-        num_shards = num_proc if num_proc is not None else 1
-        if batched and drop_last_batch:
-            pbar_total = len(self) // num_shards // batch_size * num_shards * batch_size
-        else:
-            pbar_total = len(self)
+        existing_cache_file_map: dict[int, list[str]] = defaultdict(list)
+        if cache_file_name is not None:
+            if os.path.exists(cache_file_name):
+                existing_cache_file_map[1] = [cache_file_name]
 
-        shards_done = 0
-        if num_proc is None or num_proc == 1:
-            transformed_dataset = None
-            try:
-                transformed_dataset = load_processed_shard_from_cache(dataset_kwargs)
-                logger.info(f"Loading cached processed dataset at {dataset_kwargs['cache_file_name']}")
-            except NonExistentDatasetError:
-                pass
-            if transformed_dataset is None:
-                with hf_tqdm(
-                    unit=" examples",
-                    total=pbar_total,
-                    desc=desc or "Map",
-                ) as pbar:
-                    for rank, done, content in Dataset._map_single(**dataset_kwargs):
-                        if done:
-                            shards_done += 1
-                            logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
-                            transformed_dataset = content
-                        else:
-                            pbar.update(content)
-            assert transformed_dataset is not None, "Failed to retrieve the result from map"
-            # update fingerprint if the dataset changed
-            if transformed_dataset._fingerprint != self._fingerprint:
-                transformed_dataset._fingerprint = new_fingerprint
-            return transformed_dataset
-        else:
+            assert cache_file_prefix is not None and cache_file_ext is not None
+            cache_file_with_suffix_pattern = cache_file_prefix + suffix_template + cache_file_ext
 
-            def format_cache_file_name(
-                cache_file_name: Optional[str],
-                rank: Union[int, Literal["*"]],  # noqa: F722
-            ) -> Optional[str]:
-                if not cache_file_name:
-                    return cache_file_name
-                sep = cache_file_name.rindex(".")
-                base_name, extension = cache_file_name[:sep], cache_file_name[sep:]
-                if isinstance(rank, int):
-                    cache_file_name = base_name + suffix_template.format(rank=rank, num_proc=num_proc) + extension
-                    logger.info(f"Process #{rank} will write at {cache_file_name}")
-                else:
-                    cache_file_name = (
-                        base_name
-                        + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_proc)
-                        + extension
-                    )
+            for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"):
+                suffix_variable_map = string_to_dict(cache_file, cache_file_with_suffix_pattern)
+                if suffix_variable_map is not None:
+                    file_num_proc = int(suffix_variable_map["num_proc"])
+                    existing_cache_file_map[file_num_proc].append(cache_file)
+
+        num_shards = num_proc or 1
+        if existing_cache_file_map:
+            # to avoid remapping when a different num_proc is given than when originally cached, update num_shards to
+            # what was used originally
+
+            def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
+                percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
+                num_shards_diff = abs(mapped_num_proc - num_shards)
+                return (
+                    percent_missing,  # choose the most complete set of existing cache files
+                    num_shards_diff,  # then choose the mapped_num_proc closest to the current num_proc
+                    mapped_num_proc,  # finally, choose whichever mapped_num_proc is lower
+                )
+
+            num_shards = min(existing_cache_file_map, key=select_existing_cache_files)
+
+        existing_cache_files = existing_cache_file_map[num_shards]
+
+        def format_cache_file_name(
+            cache_file_name: Optional[str],
+            rank: Union[int, Literal["*"]],  # noqa: F722
+        ) -> Optional[str]:
+            if not cache_file_name:
                 return cache_file_name
 
-            def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
-                new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_proc)
-                validate_fingerprint(new_fingerprint)
-                return new_fingerprint
+            assert cache_file_prefix is not None and cache_file_ext is not None
+
+            if isinstance(rank, int):
+                cache_file_name = (
+                    cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext
+                )
+                if not os.path.exists(cache_file_name):
+                    process_name = (
+                        "Main process" if num_proc is None or num_proc == 1 else f"Process #{rank % num_shards + 1}"
+                    )
+                    logger.info(f"{process_name} will write at {cache_file_name}")
+            else:
+                # TODO: this assumes the format_spec of rank in suffix_template
+                cache_file_name = (
+                    cache_file_prefix
+                    + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_shards)
+                    + cache_file_ext
+                )
+            return cache_file_name
+
+        def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
+            new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_shards)
+            validate_fingerprint(new_fingerprint)
+            return new_fingerprint
 
+        if num_proc is not None and num_proc > 1:
             prev_env = deepcopy(os.environ)
             # check if parallelism if off
             # from https://github.com/huggingface/tokenizers/blob/bb668bc439dc34389b71dbb8ce0c597f15707b53/tokenizers/src/utils/parallelism.rs#L22
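To see why the `select_existing_cache_files` key works, consider a hypothetical cache directory holding a complete 4-shard set and an incomplete 8-shard set from earlier runs; the tuple orders candidates by completeness, then by distance from the requested `num_proc`, then by size. A sketch with made-up file names, not the library's internals:

```python
from collections import defaultdict

existing_cache_file_map: dict = defaultdict(list)
existing_cache_file_map[4] = [f"cache_{r:05d}_of_00004.arrow" for r in range(4)]  # complete
existing_cache_file_map[8] = [f"cache_{r:05d}_of_00008.arrow" for r in range(7)]  # one shard missing

num_shards = 8  # the num_proc requested for this run

def select_existing_cache_files(mapped_num_proc: int) -> tuple:
    percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
    return (percent_missing, abs(mapped_num_proc - num_shards), mapped_num_proc)

# (0.0, 4, 4) beats (0.125, 0, 8): the complete 4-shard set wins, so the run
# reuses those files instead of remapping any examples.
assert min(existing_cache_file_map, key=select_existing_cache_files) == 4
```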
@@ -3128,9 +3149,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
             ):
                 logger.warning("Setting TOKENIZERS_PARALLELISM=false for forked processes.")
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        else:
+            prev_env = os.environ
+
+        kwargs_per_job: list[Optional[dict[str, Any]]]
+        if num_shards == 1:
+            shards = [self]
+            kwargs_per_job = [dataset_kwargs]
+        else:
             shards = [
-                self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
-                for rank in range(num_proc)
+                self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
+                for rank in range(num_shards)
             ]
             kwargs_per_job = [
                 {
@@ -3144,62 +3173,97 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
                 for rank in range(num_shards)
             ]
 
-            transformed_shards = [None] * num_shards
-            for rank in range(num_shards):
-                try:
-                    transformed_shards[rank] = load_processed_shard_from_cache(kwargs_per_job[rank])
-                    kwargs_per_job[rank] = None
-                except NonExistentDatasetError:
-                    pass
-
-            kwargs_per_job = [kwargs for kwargs in kwargs_per_job if kwargs is not None]
-
-            # We try to create a pool with as many workers as dataset not yet cached.
-            if kwargs_per_job:
-                if len(kwargs_per_job) < num_shards:
-                    logger.info(
-                        f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
-                    )
-                with Pool(len(kwargs_per_job)) as pool:
-                    os.environ = prev_env
-                    logger.info(f"Spawning {num_proc} processes")
-                    with hf_tqdm(
-                        unit=" examples",
-                        total=pbar_total,
-                        desc=(desc or "Map") + f" (num_proc={num_proc})",
-                    ) as pbar:
+        transformed_shards: list[Optional[Dataset]] = [None] * num_shards
+        for rank in range(num_shards):
+            try:
+                job_kwargs = kwargs_per_job[rank]
+                assert job_kwargs is not None
+                transformed_shards[rank] = load_processed_shard_from_cache(job_kwargs)
+                kwargs_per_job[rank] = None
+            except NonExistentDatasetError:
+                pass
+
+        if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]:
+            if len(unprocessed_kwargs_per_job) != num_shards:
+                logger.info(
+                    f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were "
+                    "missing from the cache."
+                )
+
+            pbar_total = len(self)
+            pbar_initial = len(existing_cache_files) * pbar_total // num_shards
+            if batched and drop_last_batch:
+                batch_size = batch_size or 1
+                pbar_initial = pbar_initial // num_shards // batch_size * num_shards * batch_size
+                pbar_total = pbar_total // num_shards // batch_size * num_shards * batch_size
+
+            with hf_tqdm(
+                unit=" examples",
+                initial=pbar_initial,
+                total=pbar_total,
+                desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""),
+            ) as pbar:
+                shards_done = 0
+
+                def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, int]) -> None:
+                    nonlocal shards_done
+                    if done:
+                        shards_done += 1
+                        logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
+                        assert isinstance(content, Dataset)
+                        transformed_shards[rank or 0] = content
+                    else:
+                        assert isinstance(content, int)
+                        pbar.update(content)
+
+                if num_proc is not None and num_proc > 1:
+                    with Pool(num_proc) as pool:
+                        os.environ = prev_env
+                        logger.info(f"Spawning {num_proc} processes")
+
                         for rank, done, content in iflatmap_unordered(
-                            pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
+                            pool, Dataset._map_single, kwargs_iterable=unprocessed_kwargs_per_job
                         ):
-                            if done:
-                                shards_done += 1
-                                logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
-                                transformed_shards[rank] = content
-                            else:
-                                pbar.update(content)
-                    pool.close()
-                    pool.join()
-                # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
-                for kwargs in kwargs_per_job:
-                    del kwargs["shard"]
-            else:
-                logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
-            if None in transformed_shards:
-                raise ValueError(
-                    f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at "
-                    "least one worker failed to return its results"
-                )
-            logger.info(f"Concatenating {num_proc} shards")
-            result = _concatenate_map_style_datasets(transformed_shards)
-            # update fingerprint if the dataset changed
+                            check_if_shard_done(rank, done, content)
+
+                        pool.close()
+                        pool.join()
+                else:
+                    for unprocessed_kwargs in unprocessed_kwargs_per_job:
+                        for rank, done, content in Dataset._map_single(**unprocessed_kwargs):
+                            check_if_shard_done(rank, done, content)
+
+            # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
+            for job_kwargs in unprocessed_kwargs_per_job:
+                if "shard" in job_kwargs:
+                    del job_kwargs["shard"]
+        else:
+            logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
+
+        all_transformed_shards = [shard for shard in transformed_shards if shard is not None]
+        if len(transformed_shards) != len(all_transformed_shards):
+            raise ValueError(
+                f"Failed to retrieve results from map: result list {transformed_shards} still contains None - "
+                "at least one worker failed to return its results"
+            )
+
+        if num_shards == 1:
+            result = all_transformed_shards[0]
+        else:
+            logger.info(f"Concatenating {num_shards} shards")
+            result = _concatenate_map_style_datasets(all_transformed_shards)
+
+        # update fingerprint if the dataset changed
+        result._fingerprint = (
+            new_fingerprint
             if any(
                 transformed_shard._fingerprint != shard._fingerprint
-                for transformed_shard, shard in zip(transformed_shards, shards)
-            ):
-                result._fingerprint = new_fingerprint
-            else:
-                result._fingerprint = self._fingerprint
-            return result
+                for transformed_shard, shard in zip(all_transformed_shards, shards)
+            )
+            else self._fingerprint
+        )
+
+        return result
 
     @staticmethod
     def _map_single(
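The progress-bar accounting above starts the bar at the share of examples covered by reused cache files and, when `drop_last_batch=True`, rounds both numbers down to whole batches per shard. Plugging in hypothetical numbers (1000 examples, 3 shards, batch size 32, one shard already cached) reproduces the diff's formulas:

```python
len_self = 1000                      # len(self): examples in the dataset
num_shards = 3
batch_size = 32
existing_cache_files = ["c0.arrow"]  # one shard restored from cache

pbar_total = len_self
pbar_initial = len(existing_cache_files) * pbar_total // num_shards  # 333
# drop_last_batch=True: only whole batches per shard will ever be processed
pbar_initial = pbar_initial // num_shards // batch_size * num_shards * batch_size  # 288
pbar_total = pbar_total // num_shards // batch_size * num_shards * batch_size      # 960

assert (pbar_initial, pbar_total) == (288, 960)
```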
@@ -3222,7 +3286,7 @@ def _map_single(
         rank: Optional[int] = None,
         offset: int = 0,
         try_original_type: Optional[bool] = True,
-    ) -> Iterable[tuple[int, bool, Union[int, "Dataset"]]]:
+    ) -> Iterable[tuple[Optional[int], bool, Union[int, "Dataset"]]]:
         """Apply a function to all the elements in the table (individually or in batches)
         and update the table (if function does update examples).
 
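The widened annotation reflects that `_map_single` yields `rank=None` when it runs in the calling process rather than in a worker. A toy consumer in the style of `check_if_shard_done` above (hypothetical names, not the library's internals) shows how `rank or 0` maps that case to shard slot 0:

```python
from typing import Iterator, Optional

def fake_map_single(rank: Optional[int] = None) -> Iterator[tuple]:
    # Progress ticks arrive as ints; the finished shard arrives once with done=True.
    yield rank, False, 500
    yield rank, False, 500
    yield rank, True, "shard-result"  # a Dataset in the real code

transformed_shards: list = [None]
for rank, done, content in fake_map_single(rank=None):  # single-process path
    if done:
        transformed_shards[rank or 0] = content  # rank=None lands in slot 0

assert transformed_shards == ["shard-result"]
```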
@@ -5762,7 +5826,7 @@ def push_to_hub(
     @transmit_format
     @fingerprint_transform(inplace=False)
     def add_column(
-        self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None
+        self, name: str, column: Union[list, np.ndarray], new_fingerprint: str, feature: Optional[FeatureType] = None
     ):
         """Add column to Dataset.
 
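The `add_column` fix swaps `np.array` (a factory function) for `np.ndarray` (the actual array type) in the annotation; only the latter is meaningful inside `typing.Union`:

```python
import numpy as np

arr = np.array([1, 2, 3])           # np.array builds an array...
assert isinstance(arr, np.ndarray)  # ...whose type is np.ndarray
assert callable(np.array)           # np.array itself is a function, not a type
```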