Skip to content

Commit b2da59b

Browse files
authored
[Modular] Provide option to disable custom code loading globally via env variable (#12177)
* update * update * update * update
1 parent 7aa6af1 commit b2da59b

File tree

3 files changed

+19
-45
lines changed

3 files changed

+19
-45
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def outputs(self) -> List[OutputParam]:
299299
def from_pretrained(
300300
cls,
301301
pretrained_model_name_or_path: str,
302-
trust_remote_code: Optional[bool] = None,
302+
trust_remote_code: bool = False,
303303
**kwargs,
304304
):
305305
hub_kwargs_names = [

src/diffusers/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
4646
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
4747
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
48+
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
4849

4950
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
5051
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import os
2121
import re
2222
import shutil
23-
import signal
2423
import sys
2524
import threading
2625
from pathlib import Path
@@ -34,6 +33,7 @@
3433

3534
from .. import __version__
3635
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
36+
from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
3737

3838

3939
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -159,52 +159,25 @@ def check_imports(filename):
159159
return get_relative_imports(filename)
160160

161161

162-
def _raise_timeout_error(signum, frame):
163-
raise ValueError(
164-
"Loading this model requires you to execute custom code contained in the model repository on your local "
165-
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
166-
)
167-
168-
169162
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
170-
if trust_remote_code is None:
171-
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
172-
prev_sig_handler = None
173-
try:
174-
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
175-
signal.alarm(TIME_OUT_REMOTE_CODE)
176-
while trust_remote_code is None:
177-
answer = input(
178-
f"The repository for {model_name} contains custom code which must be executed to correctly "
179-
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
180-
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
181-
f"Do you wish to run the custom code? [y/N] "
182-
)
183-
if answer.lower() in ["yes", "y", "1"]:
184-
trust_remote_code = True
185-
elif answer.lower() in ["no", "n", "0", ""]:
186-
trust_remote_code = False
187-
signal.alarm(0)
188-
except Exception:
189-
# OS which does not support signal.SIGALRM
190-
raise ValueError(
191-
f"The repository for {model_name} contains custom code which must be executed to correctly "
192-
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
193-
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
194-
)
195-
finally:
196-
if prev_sig_handler is not None:
197-
signal.signal(signal.SIGALRM, prev_sig_handler)
198-
signal.alarm(0)
199-
elif has_remote_code:
200-
# For the CI which puts the timeout at 0
201-
_raise_timeout_error(None, None)
163+
trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
164+
if DIFFUSERS_DISABLE_REMOTE_CODE:
165+
logger.warning(
166+
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
167+
)
202168

203169
if has_remote_code and not trust_remote_code:
204-
raise ValueError(
205-
f"Loading {model_name} requires you to execute the configuration file in that"
206-
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
207-
" set the option `trust_remote_code=True` to remove this error."
170+
error_msg = f"The repository for {model_name} contains custom code. "
171+
error_msg += (
172+
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
173+
if DIFFUSERS_DISABLE_REMOTE_CODE
174+
else "Pass `trust_remote_code=True` to allow loading remote code modules."
175+
)
176+
raise ValueError(error_msg)
177+
178+
elif has_remote_code and trust_remote_code:
179+
logger.warning(
180+
f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
208181
)
209182

210183
return trust_remote_code

0 commit comments

Comments
 (0)