|
| 1 | +import contextvars |
1 | 2 | import dataclasses |
2 | 3 | import re |
3 | 4 | import unittest.mock |
|
21 | 22 | import pydantic |
22 | 23 | import pydantic_core |
23 | 24 | from cached_property import cached_property |
24 | | -from pydantic import BaseModel, ConfigDict, ValidationError |
| 25 | +from pydantic import BaseModel, ConfigDict, SecretStr, ValidationError, model_validator |
25 | 26 | from pydantic.fields import Field |
26 | 27 | from typing_extensions import Protocol, Self |
27 | 28 |
|
28 | 29 | from datahub.configuration._config_enum import ConfigEnum as ConfigEnum |
| 30 | +from datahub.masking.secret_registry import SecretRegistry, is_masking_enabled |
29 | 31 | from datahub.utilities.dedup_list import deduplicate_list |
30 | 32 |
|
31 | 33 | REDACT_KEYS = { |
@@ -95,6 +97,11 @@ def redact_raw_config(obj: Any) -> Any: |
95 | 97 |
|
96 | 98 | LaxStr = Annotated[str, pydantic.BeforeValidator(lambda v: str(v))] |
97 | 99 |
|
| 100 | +# Context variable to track if we're inside a nested ConfigModel construction |
| 101 | +_inside_nested_config: contextvars.ContextVar[bool] = contextvars.ContextVar( |
| 102 | + "_inside_nested_config", default=False |
| 103 | +) |
| 104 | + |
98 | 105 |
|
99 | 106 | @dataclasses.dataclass(frozen=True) |
100 | 107 | class SupportedSources: |
@@ -129,6 +136,96 @@ class ConfigModel(BaseModel): |
129 | 136 | json_schema_extra=_config_model_schema_extra, |
130 | 137 | ) |
131 | 138 |
|
| 139 | + @model_validator(mode="wrap") |
| 140 | + @classmethod |
| 141 | + def _track_nesting_context( |
| 142 | + cls, |
| 143 | + data: Any, |
| 144 | + handler: pydantic.ValidatorFunctionWrapHandler, |
| 145 | + info: pydantic.ValidationInfo, |
| 146 | + ) -> Self: |
| 147 | + """ |
| 148 | + Wrap validator that tracks nesting context for nested ConfigModel detection. |
| 149 | +
|
| 150 | + Sets a context variable so nested ConfigModels know they're being constructed as fields. |
| 151 | + """ |
| 152 | + # Set context for any nested models that will be created during field processing |
| 153 | + token = _inside_nested_config.set(True) |
| 154 | + try: |
| 155 | + # Process the model normally (this calls __init__ and all validators) |
| 156 | + instance = handler(data) |
| 157 | + finally: |
| 158 | + # Reset context after processing |
| 159 | + _inside_nested_config.reset(token) |
| 160 | + |
| 161 | + return instance |
| 162 | + |
| 163 | + @model_validator(mode="after") |
| 164 | + def _register_secret_fields(self) -> Self: |
| 165 | + """ |
| 166 | + Register SecretStr fields with the secret masking registry. |
| 167 | + Recursively traverses nested ConfigModel instances to find all SecretStr fields. |
| 168 | +
|
| 169 | + Only models that are constructed outside of Pydantic field processing will register secrets. |
| 170 | + This ensures we capture the full qualified paths for nested secrets. |
| 171 | +
|
| 172 | + Performance: Uses batch registration for efficiency - single version |
| 173 | + increment instead of one per secret. |
| 174 | + """ |
| 175 | + if not is_masking_enabled(): |
| 176 | + return self |
| 177 | + |
| 178 | + # Only register if we're NOT inside another ConfigModel's field processing |
| 179 | + # This means we're a "root" model from the user's perspective |
| 180 | + if _inside_nested_config.get(): |
| 181 | + return self |
| 182 | + |
| 183 | + # Collect all secrets recursively (including from nested models) |
| 184 | + secrets: Dict[str, str] = {} |
| 185 | + self._collect_secrets(secrets, prefix="") |
| 186 | + |
| 187 | + # Batch register all secrets in one operation |
| 188 | + if secrets: |
| 189 | + SecretRegistry.get_instance().register_secrets_batch(secrets) |
| 190 | + |
| 191 | + return self |
| 192 | + |
| 193 | + def _collect_secrets(self, secrets: Dict[str, str], prefix: str) -> None: |
| 194 | + """ |
| 195 | + Recursively collect SecretStr fields from this model and nested ConfigModel instances. |
| 196 | +
|
| 197 | + Args: |
| 198 | + secrets: Dictionary to populate with field_name -> secret_value mappings |
| 199 | + prefix: Prefix for nested field names (e.g., "azure_auth." for nested fields) |
| 200 | + """ |
| 201 | + for field_name, _field_info in self.__class__.model_fields.items(): |
| 202 | + field_value = getattr(self, field_name, None) |
| 203 | + |
| 204 | + if field_value is None: |
| 205 | + continue |
| 206 | + |
| 207 | + # Build the full field path for better debugging |
| 208 | + full_name = f"{prefix}{field_name}" if prefix else field_name |
| 209 | + |
| 210 | + if isinstance(field_value, SecretStr): |
| 211 | + # Direct SecretStr field |
| 212 | + secret_value = field_value.get_secret_value() |
| 213 | + if secret_value: |
| 214 | + secrets[full_name] = secret_value |
| 215 | + elif isinstance(field_value, ConfigModel): |
| 216 | + # Nested ConfigModel - recurse into it |
| 217 | + field_value._collect_secrets(secrets, prefix=f"{full_name}.") |
| 218 | + elif isinstance(field_value, list): |
| 219 | + # Handle lists of ConfigModels |
| 220 | + for idx, item in enumerate(field_value): |
| 221 | + if isinstance(item, ConfigModel): |
| 222 | + item._collect_secrets(secrets, prefix=f"{full_name}[{idx}].") |
| 223 | + elif isinstance(field_value, dict): |
| 224 | + # Handle dicts with ConfigModel values |
| 225 | + for key, item in field_value.items(): |
| 226 | + if isinstance(item, ConfigModel): |
| 227 | + item._collect_secrets(secrets, prefix=f"{full_name}[{key}].") |
| 228 | + |
132 | 229 | @classmethod |
133 | 230 | def parse_obj_allow_extras(cls, obj: Any) -> Self: |
134 | 231 | """Parse an object while allowing extra fields. |
|
0 commit comments