
Commit a2dc287

Remove default trust_remote_code=True (#6954)
* remove default trust_remote_code=True
* fix tests
* fix tests
* again
* again
* style
1 parent 09ebf51 commit a2dc287
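
With this change, code-based dataset repositories are no longer trusted implicitly: the HF_DATASETS_TRUST_REMOTE_CODE default moves from "1" to "ask", and users opt in per call instead. A minimal sketch of the per-call opt-in, reusing the repository id that appears in the updated audio test below:

from datasets import load_dataset

# Explicit opt-in to executing the dataset's loading script.
ds = load_dataset(
    "patrickvonplaten/librispeech_asr_dummy",
    "clean",
    split="validation",
    trust_remote_code=True,
)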

File tree

12 files changed: +101, -54 lines


src/datasets/commands/test.py

Lines changed: 8 additions & 2 deletions
@@ -3,7 +3,7 @@
 from argparse import ArgumentParser
 from pathlib import Path
 from shutil import copyfile, rmtree
-from typing import Generator
+from typing import Generator, Optional
 
 import datasets.config
 from datasets.builder import DatasetBuilder
@@ -29,6 +29,7 @@ def _test_command_factory(args):
         args.force_redownload,
         args.clear_cache,
         args.num_proc,
+        args.trust_remote_code,
     )
 
 
@@ -67,6 +68,9 @@ def register_subcommand(parser: ArgumentParser):
             help="Remove downloaded files and cached datasets after each config test",
         )
         test_parser.add_argument("--num_proc", type=int, default=None, help="Number of processes")
+        test_parser.add_argument(
+            "--trust_remote_code", action="store_true", help="whether to trust the code execution of the load script"
+        )
         # aliases
         test_parser.add_argument("--save_infos", action="store_true", help="alias to save_info")
         test_parser.add_argument("dataset", type=str, help="Name of the dataset to download")
@@ -84,6 +88,7 @@ def __init__(
         force_redownload: bool,
         clear_cache: bool,
         num_proc: int,
+        trust_remote_code: Optional[bool],
     ):
         self._dataset = dataset
         self._name = name
@@ -95,6 +100,7 @@ def __init__(
         self._force_redownload = force_redownload
         self._clear_cache = clear_cache
         self._num_proc = num_proc
+        self._trust_remote_code = trust_remote_code
         if clear_cache and not cache_dir:
             print(
                 "When --clear_cache is used, specifying a cache directory is mandatory.\n"
@@ -111,7 +117,7 @@ def run(self):
             print("Both parameters `config` and `all_configs` can't be used at once.")
             exit(1)
         path, config_name = self._dataset, self._name
-        module = dataset_module_factory(path)
+        module = dataset_module_factory(path, trust_remote_code=self._trust_remote_code)
         builder_cls = import_main_class(module.module_path)
         n_builders = len(builder_cls.BUILDER_CONFIGS) if self._all_configs and builder_cls.BUILDER_CONFIGS else 1
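
The flag is exposed on the command line as datasets-cli test <dataset> --trust_remote_code and defaults to off. A small, hedged sketch of how the argparse wiring above behaves (the parser here is a reduced stand-in for the real subcommand parser, not the actual implementation):

from argparse import ArgumentParser

# Reduced stand-in for the `datasets-cli test` subcommand parser shown in the hunk above.
parser = ArgumentParser("datasets-cli test (sketch)")
parser.add_argument("dataset", type=str, help="Name of the dataset to download")
parser.add_argument(
    "--trust_remote_code", action="store_true", help="whether to trust the code execution of the load script"
)

# store_true means the attribute is False unless the flag is passed explicitly.
args = parser.parse_args(["some/dataset", "--trust_remote_code"])
assert args.trust_remote_code is True  # forwarded to TestCommand, then to dataset_module_factory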

src/datasets/config.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@
 HF_DATASETS_MULTITHREADING_MAX_WORKERS = 16
 
 # Remote dataset scripts support
-__HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "1")
+__HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "ask")
 HF_DATASETS_TRUST_REMOTE_CODE: Optional[bool] = (
     True
     if __HF_DATASETS_TRUST_REMOTE_CODE.upper() in ENV_VARS_TRUE_VALUES
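
Besides the per-call argument, the opt-in can still be made global through the environment variable read here. A hedged sketch; it assumes "1" is among ENV_VARS_TRUE_VALUES (as the old default implies), and the variable must be set before datasets is imported because the value above is computed at import time:

import os

# Assumption: "1" is in ENV_VARS_TRUE_VALUES, as implied by the previous default above.
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"

import datasets  # imported after setting the variable, since config reads it at import time

print(datasets.config.HF_DATASETS_TRUST_REMOTE_CODE)  # True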

src/datasets/load.py

Lines changed: 4 additions & 1 deletion
@@ -2007,7 +2007,10 @@ def metric_module_factory(
         raise FileNotFoundError(f"Couldn't find a metric script at {relative_to_absolute_path(path)}")
     elif os.path.isfile(combined_path):
         return LocalMetricModuleFactory(
-            combined_path, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path
+            combined_path,
+            download_mode=download_mode,
+            dynamic_modules_path=dynamic_modules_path,
+            trust_remote_code=trust_remote_code,
         ).get_module()
     elif is_relative_path(path) and path.count("/") == 0:
         try:
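
The local metric path now forwards the flag as well. A hedged sketch of a direct call to the factory touched here (the function is deprecated in favor of the evaluate library, and "accuracy" is just an example metric name):

from datasets.load import metric_module_factory

# Deprecated API (emits a FutureWarning); the metric loading script is only
# executed because trust_remote_code is passed explicitly.
metric_module = metric_module_factory("accuracy", trust_remote_code=True)
print(metric_module.module_path)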

tests/commands/test_test.py

Lines changed: 5 additions & 2 deletions
@@ -21,8 +21,9 @@
         "force_redownload",
         "clear_cache",
         "num_proc",
+        "trust_remote_code",
     ],
-    defaults=[None, None, None, False, False, False, False, False, None],
+    defaults=[None, None, None, False, False, False, False, False, None, None],
 )
 
 
@@ -32,7 +33,9 @@ def is_1percent_close(source, target):
 
 @pytest.mark.integration
 def test_test_command(dataset_loading_script_dir):
-    args = _TestCommandArgs(dataset=dataset_loading_script_dir, all_configs=True, save_infos=True)
+    args = _TestCommandArgs(
+        dataset=dataset_loading_script_dir, all_configs=True, save_infos=True, trust_remote_code=True
+    )
     test_command = TestCommand(*args)
     test_command.run()
     dataset_readme_path = os.path.join(dataset_loading_script_dir, "README.md")

tests/features/test_audio.py

Lines changed: 2 additions & 2 deletions
@@ -604,9 +604,9 @@ def test_load_dataset_with_audio_feature(streaming, jsonl_audio_dataset_path, sh
 @pytest.mark.integration
 def test_dataset_with_audio_feature_loaded_from_cache():
     # load first time
-    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
+    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", trust_remote_code=True)
     # load from cache
-    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", trust_remote_code=True, split="validation")
     assert isinstance(ds, Dataset)
 
 

tests/features/test_image.py

Lines changed: 3 additions & 1 deletion
@@ -614,7 +614,9 @@ def test_load_dataset_with_image_feature(shared_datadir, data_dir, dataset_loadi
     import PIL.Image
 
     image_path = str(shared_datadir / "test_image_rgb.jpg")
-    dset = load_dataset(dataset_loading_script_dir, split="train", data_dir=data_dir, streaming=streaming)
+    dset = load_dataset(
+        dataset_loading_script_dir, split="train", data_dir=data_dir, streaming=streaming, trust_remote_code=True
+    )
     item = dset[0] if not streaming else next(iter(dset))
     assert item.keys() == {"image", "caption"}
     assert isinstance(item["image"], PIL.Image.Image)

tests/test_hf_gcp.py

Lines changed: 7 additions & 2 deletions
@@ -72,6 +72,7 @@ def test_dataset_info_available(self, dataset, config_name, revision):
                 config_name,
                 revision=revision,
                 cache_dir=tmp_dir,
+                trust_remote_code=True,
             )
 
             dataset_info_url = "/".join(
@@ -88,7 +89,7 @@ def test_dataset_info_available(self, dataset, config_name, revision):
 @pytest.mark.integration
 def test_as_dataset_from_hf_gcs(tmp_path_factory):
     tmp_dir = tmp_path_factory.mktemp("test_hf_gcp") / "test_wikipedia_simple"
-    builder = load_dataset_builder("wikipedia", "20220301.frr", cache_dir=tmp_dir)
+    builder = load_dataset_builder("wikipedia", "20220301.frr", cache_dir=tmp_dir, trust_remote_code=True)
     # use the HF cloud storage, not the original download_and_prepare that uses apache-beam
     builder._download_and_prepare = None
     builder.download_and_prepare(try_from_hf_gcs=True)
@@ -99,7 +100,11 @@ def test_as_dataset_from_hf_gcs(tmp_path_factory):
 @pytest.mark.integration
 def test_as_streaming_dataset_from_hf_gcs(tmp_path):
     builder = load_dataset_builder(
-        "wikipedia", "20220301.frr", revision="4d013bdd32c475c8536aae00a56efc774f061649", cache_dir=tmp_path
+        "wikipedia",
+        "20220301.frr",
+        revision="4d013bdd32c475c8536aae00a56efc774f061649",
+        cache_dir=tmp_path,
+        trust_remote_code=True,
     )
     ds = builder.as_streaming_dataset()
     assert ds

tests/test_hub.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_
     with patch.object(datasets.hub.HfApi, "create_commit", return_value=commit_info) as mock_create_commit:
         with patch.object(datasets.hub.HfApi, "create_branch") as mock_create_branch:
             with patch.object(datasets.hub.HfApi, "list_repo_tree", return_value=[]):  # not needed
-                _ = convert_to_parquet(repo_id, token=hf_token)
+                _ = convert_to_parquet(repo_id, token=hf_token, trust_remote_code=True)
     # mock_create_branch
     assert mock_create_branch.called
     assert mock_create_branch.call_count == 2

tests/test_inspect.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,7 @@ def test_inspect_dataset(path, tmp_path):
 @pytest.mark.filterwarnings("ignore:metric_module_factory is deprecated:FutureWarning")
 @pytest.mark.parametrize("path", ["accuracy"])
 def test_inspect_metric(path, tmp_path):
-    inspect_metric(path, tmp_path)
+    inspect_metric(path, tmp_path, trust_remote_code=True)
     script_name = path + ".py"
     assert script_name in os.listdir(tmp_path)
     assert "__pycache__" not in os.listdir(tmp_path)
@@ -79,7 +79,7 @@ def test_get_dataset_config_info_error(path, config_name, expected_exception):
     ],
 )
 def test_get_dataset_config_names(path, expected):
-    config_names = get_dataset_config_names(path)
+    config_names = get_dataset_config_names(path, trust_remote_code=True)
     assert config_names == expected
 
 
@@ -97,7 +97,7 @@ def test_get_dataset_config_names(path, expected):
     ],
 )
 def test_get_dataset_default_config_name(path, expected):
-    default_config_name = get_dataset_default_config_name(path)
+    default_config_name = get_dataset_default_config_name(path, trust_remote_code=True)
     if expected:
         assert default_config_name == expected
     else:
