From 9a687964a0465ba16b6c393c7b2974e034340f77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20G=C3=B3mez=20Villamor?= Date: Mon, 10 Nov 2025 15:28:04 +0000 Subject: [PATCH 1/4] refactor(pydantic): remove Pydantic v1 legacy code and migrate fully to v2 --- metadata-ingestion/setup.cfg | 5 +- .../api/entities/assertion/assertion.py | 22 +- .../assertion/assertion_config_spec.py | 13 +- .../entities/assertion/assertion_operator.py | 69 +- .../entities/assertion/assertion_trigger.py | 40 +- .../entities/assertion/datahub_assertion.py | 62 +- .../api/entities/assertion/field_assertion.py | 53 +- .../datahub/api/entities/assertion/filter.py | 7 +- .../entities/assertion/freshness_assertion.py | 63 +- .../api/entities/assertion/sql_assertion.py | 50 +- .../entities/assertion/volume_assertion.py | 58 +- .../api/entities/datacontract/assertion.py | 6 +- .../datacontract/assertion_operator.py | 23 +- .../datacontract/data_quality_assertion.py | 22 +- .../api/entities/datacontract/datacontract.py | 37 +- .../datacontract/freshness_assertion.py | 18 +- .../entities/datacontract/schema_assertion.py | 20 +- .../datahub/api/entities/dataset/dataset.py | 181 +-- .../src/datahub/configuration/_config_enum.py | 34 +- .../src/datahub/configuration/common.py | 57 +- .../pydantic_migration_helpers.py | 52 - .../ingestion/glossary/datahub_classifier.py | 11 +- .../source/abs/datalake_profiler_config.py | 2 +- .../ingestion/source/dremio/dremio_source.py | 1304 ++++++++--------- .../ingestion/source/ge_profiling_config.py | 2 +- .../source/s3/datalake_profiler_config.py | 2 +- .../datahub/ingestion/source/sql_queries.py | 4 +- .../source/state/stateful_ingestion_base.py | 3 +- .../assertion/snowflake/compiler.py | 52 +- .../src/datahub/pydantic/compat.py | 58 - .../src/datahub/sdk/search_filters.py | 75 +- .../src/datahub/sql_parsing/_models.py | 9 +- .../datahub/utilities/lossy_collections.py | 14 +- .../integration/dynamodb/test_dynamodb.py | 10 +- .../integration/snowflake/test_snowflake.py | 16 +- .../unit/cli/assertion/dmf_definitions.sql | 2 +- .../tests/unit/sdk_v2/test_search_client.py | 24 +- smoke-test/pyproject.toml | 2 +- 38 files changed, 1114 insertions(+), 1368 deletions(-) delete mode 100644 metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py delete mode 100644 metadata-ingestion/src/datahub/pydantic/compat.py diff --git a/metadata-ingestion/setup.cfg b/metadata-ingestion/setup.cfg index 804765439e39e2..a3b552b8efe07b 100644 --- a/metadata-ingestion/setup.cfg +++ b/metadata-ingestion/setup.cfg @@ -1,8 +1,7 @@ [mypy] plugins = ./tests/test_helpers/sqlalchemy_mypy_plugin.py, - pydantic.mypy, - pydantic.v1.mypy + pydantic.mypy exclude = ^(venv/|build/|dist/|examples/transforms/setup.py) ignore_missing_imports = yes namespace_packages = no @@ -75,8 +74,6 @@ filterwarnings = ignore:Deprecated call to \`pkg_resources.declare_namespace:DeprecationWarning ignore:pkg_resources is deprecated as an API:DeprecationWarning ignore:Did not recognize type:sqlalchemy.exc.SAWarning - # TODO: We should remove this and start fixing the deprecations. - ignore::pydantic.warnings.PydanticDeprecatedSince20 ignore::datahub.configuration.common.ConfigurationWarning ignore:The new datahub SDK:datahub.errors.ExperimentalWarning # We should not be unexpectedly seeing API tracing warnings. 
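The file-by-file hunks that follow apply a small set of mechanical conversions: v1_ConfigModel with Extra.forbid becomes pydantic.BaseModel with model_config = {"extra": "forbid"}; v1_Field becomes pydantic.Field; @v1_validator(..., pre=True) becomes @field_validator(..., mode="before") plus @classmethod; __root__ wrapper models become RootModel subclasses with a root field (or, in the assertion package, plain typing.Union aliases relying on discriminated/smart unions); and parse_obj()/dict() become model_validate()/model_dump(). The standalone sketch below shows these patterns together in one place; it is illustrative only, and the class names are invented for the example rather than taken from this patch.

# Minimal sketch of the v1 -> v2 patterns applied throughout this patch.
# Class names here are hypothetical; only the pydantic APIs mirror the diff.
from datetime import timedelta
from typing import Union

import humanfriendly
from pydantic import BaseModel, Field, RootModel, field_validator
from typing_extensions import Literal


class CronTriggerExample(BaseModel):
    # v1: `class Config: extra = Extra.forbid` -> v2: model_config dict
    model_config = {"extra": "forbid"}

    type: Literal["cron"]
    cron: str = Field(description="Cron expression, e.g. '0 * * * *'.")
    timezone: str = Field("UTC", description="Timezone for the schedule.")


class IntervalTriggerExample(BaseModel):
    model_config = {"extra": "forbid"}

    type: Literal["interval"]
    interval: timedelta

    # v1: @validator("interval", pre=True) -> v2: @field_validator(..., mode="before")
    @field_validator("interval", mode="before")
    @classmethod
    def parse_interval(cls, v):
        if isinstance(v, str):
            return timedelta(seconds=humanfriendly.parse_timespan(v))
        return v


# v1: a model with `__root__: Union[...] = Field(discriminator="type")`
# v2: a RootModel parameterized by the union, with the discriminator on `root`
class TriggerExample(RootModel[Union[CronTriggerExample, IntervalTriggerExample]]):
    root: Union[CronTriggerExample, IntervalTriggerExample] = Field(
        discriminator="type"
    )


# v1: parse_obj() / dict() -> v2: model_validate() / model_dump()
trigger = TriggerExample.model_validate({"type": "interval", "interval": "6 hours"})
assert trigger.root.interval == timedelta(hours=6)
print(trigger.model_dump())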
diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py index e5a4c5ee602de4..9a934be91e9e77 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py @@ -1,12 +1,15 @@ from abc import abstractmethod from typing import Optional +from pydantic import BaseModel, Field + from datahub.api.entities.assertion.assertion_trigger import AssertionTrigger -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel, v1_Field from datahub.metadata.com.linkedin.pegasus2avro.assertion import AssertionInfo -class BaseAssertionProtocol(v1_ConfigModel): +class BaseAssertionProtocol(BaseModel): + model_config = {"extra": "forbid"} + @abstractmethod def get_id(self) -> str: pass @@ -24,15 +27,17 @@ def get_assertion_trigger( pass -class BaseAssertion(v1_ConfigModel): - id_raw: Optional[str] = v1_Field( +class BaseAssertion(BaseModel): + model_config = {"extra": "forbid"} + + id_raw: Optional[str] = Field( default=None, description="The raw id of the assertion." "If provided, this is used when creating identifier for this assertion" "along with assertion type and entity.", ) - id: Optional[str] = v1_Field( + id: Optional[str] = Field( default=None, description="The id of the assertion." "If provided, this is used as identifier for this assertion." @@ -41,17 +46,14 @@ class BaseAssertion(v1_ConfigModel): description: Optional[str] = None - # Can contain metadata extracted from datahub. e.g. - # - entity qualified name - # - entity schema meta: Optional[dict] = None class BaseEntityAssertion(BaseAssertion): - entity: str = v1_Field( + entity: str = Field( description="The entity urn that the assertion is associated with" ) - trigger: Optional[AssertionTrigger] = v1_Field( + trigger: Optional[AssertionTrigger] = Field( default=None, description="The trigger schedule for assertion", alias="schedule" ) diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_config_spec.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_config_spec.py index 08205cc621253f..d12d79a9fbbd91 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_config_spec.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_config_spec.py @@ -1,13 +1,13 @@ from typing import List, Optional +from pydantic import BaseModel, Field from ruamel.yaml import YAML from typing_extensions import Literal from datahub.api.entities.assertion.datahub_assertion import DataHubAssertion -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel, v1_Field -class AssertionsConfigSpec(v1_ConfigModel): +class AssertionsConfigSpec(BaseModel): """ Declarative configuration specification for datahub assertions. @@ -18,9 +18,11 @@ class AssertionsConfigSpec(v1_ConfigModel): In future, this would invoke datahub GraphQL API to upsert assertions. 
""" + model_config = {"extra": "forbid"} + version: Literal[1] - id: Optional[str] = v1_Field( + id: Optional[str] = Field( default=None, alias="namespace", description="Unique identifier of assertions configuration file", @@ -34,8 +36,7 @@ def from_yaml( file: str, ) -> "AssertionsConfigSpec": with open(file) as fp: - yaml = YAML(typ="rt") # default, if not specfied, is 'rt' (round-trip) + yaml = YAML(typ="rt") orig_dictionary = yaml.load(fp) - parsed_spec = AssertionsConfigSpec.parse_obj(orig_dictionary) - # parsed_spec._original_yaml_dict = orig_dictionary + parsed_spec = AssertionsConfigSpec.model_validate(orig_dictionary) return parsed_spec diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py index a05386798495de..7e1259a22c0d6f 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py @@ -1,9 +1,9 @@ import json from typing import List, Optional, Union +from pydantic import BaseModel from typing_extensions import Literal, Protocol -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel from datahub.metadata.schema_classes import ( AssertionStdOperatorClass, AssertionStdParameterClass, @@ -61,7 +61,9 @@ def _generate_assertion_std_parameters( ) -class EqualToOperator(v1_ConfigModel): +class EqualToOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["equal_to"] value: Union[str, int, float] @@ -74,7 +76,8 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class NotEqualToOperator(v1_ConfigModel): +class NotEqualToOperator(BaseModel): + model_config = {"extra": "forbid"} type: Literal["not_equal_to"] value: Union[str, int, float] @@ -87,7 +90,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class BetweenOperator(v1_ConfigModel): +class BetweenOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["between"] min: Union[int, float] max: Union[int, float] @@ -103,7 +108,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: ) -class LessThanOperator(v1_ConfigModel): +class LessThanOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["less_than"] value: Union[int, float] @@ -116,7 +123,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class GreaterThanOperator(v1_ConfigModel): +class GreaterThanOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["greater_than"] value: Union[int, float] @@ -129,7 +138,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class LessThanOrEqualToOperator(v1_ConfigModel): +class LessThanOrEqualToOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["less_than_or_equal_to"] value: Union[int, float] @@ -142,7 +153,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class GreaterThanOrEqualToOperator(v1_ConfigModel): +class GreaterThanOrEqualToOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["greater_than_or_equal_to"] value: Union[int, float] @@ -155,7 +168,9 @@ def generate_parameters(self) -> 
AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class InOperator(v1_ConfigModel): +class InOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["in"] value: List[Union[str, float, int]] @@ -168,7 +183,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class NotInOperator(v1_ConfigModel): +class NotInOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["not_in"] value: List[Union[str, float, int]] @@ -181,7 +198,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class IsNullOperator(v1_ConfigModel): +class IsNullOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["is_null"] operator: str = AssertionStdOperatorClass.NULL @@ -193,7 +212,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters() -class NotNullOperator(v1_ConfigModel): +class NotNullOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["is_not_null"] operator: str = AssertionStdOperatorClass.NOT_NULL @@ -205,7 +226,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters() -class IsTrueOperator(v1_ConfigModel): +class IsTrueOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["is_true"] operator: str = AssertionStdOperatorClass.IS_TRUE @@ -217,7 +240,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters() -class IsFalseOperator(v1_ConfigModel): +class IsFalseOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["is_false"] operator: str = AssertionStdOperatorClass.IS_FALSE @@ -229,7 +254,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters() -class ContainsOperator(v1_ConfigModel): +class ContainsOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["contains"] value: str @@ -242,7 +269,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class EndsWithOperator(v1_ConfigModel): +class EndsWithOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["ends_with"] value: str @@ -255,7 +284,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class StartsWithOperator(v1_ConfigModel): +class StartsWithOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["starts_with"] value: str @@ -268,7 +299,9 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class MatchesRegexOperator(v1_ConfigModel): +class MatchesRegexOperator(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["matches_regex"] value: str diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py index d7809164847447..5542cdab5f9dee 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py @@ -2,31 +2,31 @@ from typing import Union import humanfriendly +from pydantic import BaseModel, Field, RootModel, field_validator from typing_extensions import 
Literal -from datahub.configuration.pydantic_migration_helpers import ( - v1_ConfigModel, - v1_Field, - v1_validator, -) +class CronTrigger(BaseModel): + model_config = {"extra": "forbid"} -class CronTrigger(v1_ConfigModel): type: Literal["cron"] - cron: str = v1_Field( + cron: str = Field( description="The cron expression to use. See https://crontab.guru/ for help." ) - timezone: str = v1_Field( + timezone: str = Field( "UTC", description="The timezone to use for the cron schedule. Defaults to UTC.", ) -class IntervalTrigger(v1_ConfigModel): +class IntervalTrigger(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["interval"] interval: timedelta - @v1_validator("interval", pre=True) + @field_validator("interval", mode="before") + @classmethod def lookback_interval_to_timedelta(cls, v): if isinstance(v, str): seconds = humanfriendly.parse_timespan(v) @@ -34,19 +34,25 @@ def lookback_interval_to_timedelta(cls, v): raise ValueError("Invalid value.") -class EntityChangeTrigger(v1_ConfigModel): +class EntityChangeTrigger(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["on_table_change"] -class ManualTrigger(v1_ConfigModel): +class ManualTrigger(BaseModel): + model_config = {"extra": "forbid"} + type: Literal["manual"] -class AssertionTrigger(v1_ConfigModel): - __root__: Union[ - CronTrigger, IntervalTrigger, EntityChangeTrigger, ManualTrigger - ] = v1_Field(discriminator="type") +class AssertionTrigger( + RootModel[Union[CronTrigger, IntervalTrigger, EntityChangeTrigger, ManualTrigger]] +): + root: Union[CronTrigger, IntervalTrigger, EntityChangeTrigger, ManualTrigger] = ( + Field(discriminator="type") + ) @property def trigger(self): - return self.__root__ + return self.root diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/datahub_assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/datahub_assertion.py index ed18b78418d768..55c45eacbd23f3 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/datahub_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/datahub_assertion.py @@ -1,35 +1,31 @@ -from typing import Optional, Union +from typing import Union -from datahub.api.entities.assertion.assertion import BaseAssertionProtocol -from datahub.api.entities.assertion.assertion_trigger import AssertionTrigger -from datahub.api.entities.assertion.field_assertion import FieldAssertion -from datahub.api.entities.assertion.freshness_assertion import FreshnessAssertion -from datahub.api.entities.assertion.sql_assertion import SQLAssertion -from datahub.api.entities.assertion.volume_assertion import VolumeAssertion -from datahub.configuration.pydantic_migration_helpers import v1_Field -from datahub.metadata.com.linkedin.pegasus2avro.assertion import AssertionInfo +from datahub.api.entities.assertion.field_assertion import ( + FieldMetricAssertion, + FieldValuesAssertion, +) +from datahub.api.entities.assertion.freshness_assertion import ( + CronFreshnessAssertion, + FixedIntervalFreshnessAssertion, +) +from datahub.api.entities.assertion.sql_assertion import ( + SqlMetricAssertion, + SqlMetricChangeAssertion, +) +from datahub.api.entities.assertion.volume_assertion import ( + RowCountChangeVolumeAssertion, + RowCountTotalVolumeAssertion, +) - -class DataHubAssertion(BaseAssertionProtocol): - __root__: Union[ - FreshnessAssertion, - VolumeAssertion, - SQLAssertion, - FieldAssertion, - # TODO: Add SchemaAssertion - ] = v1_Field(discriminator="type") - - @property - def assertion(self): - return 
self.__root__.assertion - - def get_assertion_info_aspect( - self, - ) -> AssertionInfo: - return self.__root__.get_assertion_info_aspect() - - def get_id(self) -> str: - return self.__root__.get_id() - - def get_assertion_trigger(self) -> Optional[AssertionTrigger]: - return self.__root__.get_assertion_trigger() +# Pydantic v2 smart union: automatically discriminates based on the 'type' field +# (eg freshness/volume/sql/field) and unique fields within each type +DataHubAssertion = Union[ + FixedIntervalFreshnessAssertion, + CronFreshnessAssertion, + RowCountTotalVolumeAssertion, + RowCountChangeVolumeAssertion, + SqlMetricAssertion, + SqlMetricChangeAssertion, + FieldMetricAssertion, + FieldValuesAssertion, +] diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py index ae062c3a8e5cbd..200d0149bd1dc1 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py @@ -1,17 +1,16 @@ from enum import Enum from typing import Optional, Union +from pydantic import BaseModel, Field from typing_extensions import Literal from datahub.api.entities.assertion.assertion import ( - BaseAssertionProtocol, BaseEntityAssertion, ) from datahub.api.entities.assertion.assertion_operator import Operators from datahub.api.entities.assertion.assertion_trigger import AssertionTrigger from datahub.api.entities.assertion.field_metric import FieldMetric from datahub.api.entities.assertion.filter import DatasetFilter -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel, v1_Field from datahub.emitter.mce_builder import datahub_guid from datahub.metadata.com.linkedin.pegasus2avro.assertion import ( AssertionInfo, @@ -30,9 +29,11 @@ ) -class FieldValuesFailThreshold(v1_ConfigModel): - type: Literal["count", "percentage"] = v1_Field(default="count") - value: int = v1_Field(default=0) +class FieldValuesFailThreshold(BaseModel): + model_config = {"extra": "forbid"} + + type: Literal["count", "percentage"] = Field(default="count") + value: int = Field(default=0) def to_field_values_failure_threshold(self) -> FieldValuesFailThresholdClass: return FieldValuesFailThresholdClass( @@ -52,13 +53,13 @@ class FieldTransform(Enum): class FieldValuesAssertion(BaseEntityAssertion): type: Literal["field"] field: str - field_transform: Optional[FieldTransform] = v1_Field(default=None) - operator: Operators = v1_Field(discriminator="type", alias="condition") - filters: Optional[DatasetFilter] = v1_Field(default=None) - failure_threshold: FieldValuesFailThreshold = v1_Field( + field_transform: Optional[FieldTransform] = Field(default=None) + operator: Operators = Field(discriminator="type", alias="condition") + filters: Optional[DatasetFilter] = Field(default=None) + failure_threshold: FieldValuesFailThreshold = Field( default=FieldValuesFailThreshold() ) - exclude_nulls: bool = v1_Field(default=True) + exclude_nulls: bool = Field(default=True) def get_assertion_info( self, @@ -98,13 +99,19 @@ def get_id(self) -> str: } return self.id or datahub_guid(guid_dict) + def get_assertion_info_aspect(self) -> AssertionInfo: + return self.get_assertion_info() + + def get_assertion_trigger(self) -> Optional[AssertionTrigger]: + return self.trigger + class FieldMetricAssertion(BaseEntityAssertion): type: Literal["field"] field: str - operator: Operators = v1_Field(discriminator="type", alias="condition") + operator: 
Operators = Field(discriminator="type", alias="condition") metric: FieldMetric - filters: Optional[DatasetFilter] = v1_Field(default=None) + filters: Optional[DatasetFilter] = Field(default=None) def get_assertion_info( self, @@ -138,21 +145,13 @@ def get_id(self) -> str: } return self.id or datahub_guid(guid_dict) + def get_assertion_info_aspect(self) -> AssertionInfo: + return self.get_assertion_info() -class FieldAssertion(BaseAssertionProtocol): - __root__: Union[FieldMetricAssertion, FieldValuesAssertion] - - @property - def assertion(self): - return self.__root__ - - def get_id(self) -> str: - return self.__root__.get_id() + def get_assertion_trigger(self) -> Optional[AssertionTrigger]: + return self.trigger - def get_assertion_info_aspect( - self, - ) -> AssertionInfo: - return self.__root__.get_assertion_info() - def get_assertion_trigger(self) -> Optional[AssertionTrigger]: - return self.__root__.trigger +# Pydantic v2 smart union: automatically discriminates based on presence of +# unique fields (eg metric field vs operator+failure_threshold combination) +FieldAssertion = Union[FieldMetricAssertion, FieldValuesAssertion] diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/filter.py b/metadata-ingestion/src/datahub/api/entities/assertion/filter.py index 05d75b674d6af9..faffc140eee5b4 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/filter.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/filter.py @@ -1,13 +1,12 @@ +from pydantic import BaseModel from typing_extensions import Literal -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel +class SqlFilter(BaseModel): + model_config = {"extra": "forbid"} -class SqlFilter(v1_ConfigModel): type: Literal["sql"] sql: str DatasetFilter = SqlFilter -# class DatasetFilter(v1_ConfigModel): -# __root__: Union[SqlFilter] = v1_Field(discriminator="type") diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/freshness_assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/freshness_assertion.py index f9e1df7d68f271..b954034bb57293 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/freshness_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/freshness_assertion.py @@ -3,15 +3,14 @@ from typing import Optional, Union import humanfriendly +from pydantic import Field, field_validator from typing_extensions import Literal from datahub.api.entities.assertion.assertion import ( - BaseAssertionProtocol, BaseEntityAssertion, ) from datahub.api.entities.assertion.assertion_trigger import AssertionTrigger from datahub.api.entities.assertion.filter import DatasetFilter -from datahub.configuration.pydantic_migration_helpers import v1_Field, v1_validator from datahub.emitter.mce_builder import datahub_guid from datahub.metadata.com.linkedin.pegasus2avro.assertion import ( AssertionInfo, @@ -32,19 +31,18 @@ class FreshnessSourceType(Enum): class CronFreshnessAssertion(BaseEntityAssertion): type: Literal["freshness"] - freshness_type: Literal["cron"] - cron: str = v1_Field( + cron: str = Field( description="The cron expression to use. See https://crontab.guru/ for help." ) - timezone: str = v1_Field( + timezone: str = Field( "UTC", description="The timezone to use for the cron schedule. 
Defaults to UTC.", ) - source_type: FreshnessSourceType = v1_Field( + source_type: FreshnessSourceType = Field( default=FreshnessSourceType.LAST_MODIFIED_COLUMN ) last_modified_field: str - filters: Optional[DatasetFilter] = v1_Field(default=None) + filters: Optional[DatasetFilter] = Field(default=None) def get_assertion_info( self, @@ -62,18 +60,32 @@ def get_assertion_info( ), ) + def get_id(self) -> str: + guid_dict = { + "entity": self.entity, + "type": self.type, + "id_raw": self.id_raw, + } + return self.id or datahub_guid(guid_dict) + + def get_assertion_info_aspect(self) -> AssertionInfo: + return self.get_assertion_info() + + def get_assertion_trigger(self) -> Optional[AssertionTrigger]: + return self.trigger + class FixedIntervalFreshnessAssertion(BaseEntityAssertion): type: Literal["freshness"] - freshness_type: Literal["interval"] = v1_Field(default="interval") lookback_interval: timedelta - filters: Optional[DatasetFilter] = v1_Field(default=None) - source_type: FreshnessSourceType = v1_Field( + filters: Optional[DatasetFilter] = Field(default=None) + source_type: FreshnessSourceType = Field( default=FreshnessSourceType.LAST_MODIFIED_COLUMN ) last_modified_field: str - @v1_validator("lookback_interval", pre=True) + @field_validator("lookback_interval", mode="before") + @classmethod def lookback_interval_to_timedelta(cls, v): if isinstance(v, str): seconds = humanfriendly.parse_timespan(v) @@ -99,26 +111,21 @@ def get_assertion_info( ), ) - -class FreshnessAssertion(BaseAssertionProtocol): - __root__: Union[FixedIntervalFreshnessAssertion, CronFreshnessAssertion] - - @property - def assertion(self): - return self.__root__ - def get_id(self) -> str: guid_dict = { - "entity": self.__root__.entity, - "type": self.__root__.type, - "id_raw": self.__root__.id_raw, + "entity": self.entity, + "type": self.type, + "id_raw": self.id_raw, } - return self.__root__.id or datahub_guid(guid_dict) + return self.id or datahub_guid(guid_dict) - def get_assertion_info_aspect( - self, - ) -> AssertionInfo: - return self.__root__.get_assertion_info() + def get_assertion_info_aspect(self) -> AssertionInfo: + return self.get_assertion_info() def get_assertion_trigger(self) -> Optional[AssertionTrigger]: - return self.__root__.trigger + return self.trigger + + +# Pydantic v2 smart union: automatically discriminates based on presence of +# unique fields (eg lookback_interval vs cron) +FreshnessAssertion = Union[FixedIntervalFreshnessAssertion, CronFreshnessAssertion] diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/sql_assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/sql_assertion.py index 3d12cfde428f4e..e71731efbb51bd 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/sql_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/sql_assertion.py @@ -1,14 +1,13 @@ from typing import Optional, Union +from pydantic import Field from typing_extensions import Literal from datahub.api.entities.assertion.assertion import ( - BaseAssertionProtocol, BaseEntityAssertion, ) from datahub.api.entities.assertion.assertion_operator import Operators from datahub.api.entities.assertion.assertion_trigger import AssertionTrigger -from datahub.configuration.pydantic_migration_helpers import v1_Field from datahub.emitter.mce_builder import datahub_guid from datahub.metadata.com.linkedin.pegasus2avro.assertion import ( AssertionInfo, @@ -22,7 +21,7 @@ class SqlMetricAssertion(BaseEntityAssertion): type: Literal["sql"] statement: str - operator: Operators 
= v1_Field(discriminator="type", alias="condition") + operator: Operators = Field(discriminator="type", alias="condition") def get_assertion_info( self, @@ -39,12 +38,26 @@ def get_assertion_info( ), ) + def get_id(self) -> str: + guid_dict = { + "entity": self.entity, + "type": self.type, + "id_raw": self.id_raw, + } + return self.id or datahub_guid(guid_dict) + + def get_assertion_info_aspect(self) -> AssertionInfo: + return self.get_assertion_info() + + def get_assertion_trigger(self) -> Optional[AssertionTrigger]: + return self.trigger + class SqlMetricChangeAssertion(BaseEntityAssertion): type: Literal["sql"] statement: str change_type: Literal["absolute", "percentage"] - operator: Operators = v1_Field(discriminator="type", alias="condition") + operator: Operators = Field(discriminator="type", alias="condition") def get_assertion_info( self, @@ -66,26 +79,21 @@ def get_assertion_info( ), ) - -class SQLAssertion(BaseAssertionProtocol): - __root__: Union[SqlMetricAssertion, SqlMetricChangeAssertion] = v1_Field() - - @property - def assertion(self): - return self.__root__ - def get_id(self) -> str: guid_dict = { - "entity": self.__root__.entity, - "type": self.__root__.type, - "id_raw": self.__root__.id_raw, + "entity": self.entity, + "type": self.type, + "id_raw": self.id_raw, } - return self.__root__.id or datahub_guid(guid_dict) + return self.id or datahub_guid(guid_dict) - def get_assertion_info_aspect( - self, - ) -> AssertionInfo: - return self.__root__.get_assertion_info() + def get_assertion_info_aspect(self) -> AssertionInfo: + return self.get_assertion_info() def get_assertion_trigger(self) -> Optional[AssertionTrigger]: - return self.__root__.trigger + return self.trigger + + +# Pydantic v2 smart union: automatically discriminates based on presence of +# unique fields (eg absence vs presence of change_type) +SQLAssertion = Union[SqlMetricAssertion, SqlMetricChangeAssertion] diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/volume_assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/volume_assertion.py index da6a125874aa72..e7062c3ef10ec4 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/volume_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/volume_assertion.py @@ -1,15 +1,14 @@ from typing import Optional, Union +from pydantic import Field from typing_extensions import Literal from datahub.api.entities.assertion.assertion import ( - BaseAssertionProtocol, BaseEntityAssertion, ) from datahub.api.entities.assertion.assertion_operator import Operators from datahub.api.entities.assertion.assertion_trigger import AssertionTrigger from datahub.api.entities.assertion.filter import DatasetFilter -from datahub.configuration.pydantic_migration_helpers import v1_Field from datahub.emitter.mce_builder import datahub_guid from datahub.metadata.com.linkedin.pegasus2avro.assertion import ( AssertionInfo, @@ -24,9 +23,9 @@ class RowCountTotalVolumeAssertion(BaseEntityAssertion): type: Literal["volume"] - metric: Literal["row_count"] = v1_Field(default="row_count") - operator: Operators = v1_Field(discriminator="type", alias="condition") - filters: Optional[DatasetFilter] = v1_Field(default=None) + metric: Literal["row_count"] = Field(default="row_count") + operator: Operators = Field(discriminator="type", alias="condition") + filters: Optional[DatasetFilter] = Field(default=None) def get_assertion_info( self, @@ -44,13 +43,27 @@ def get_assertion_info( ), ) + def get_id(self) -> str: + guid_dict = { + "entity": 
self.entity, + "type": self.type, + "id_raw": self.id_raw, + } + return self.id or datahub_guid(guid_dict) + + def get_assertion_info_aspect(self) -> AssertionInfo: + return self.get_assertion_info() + + def get_assertion_trigger(self) -> Optional[AssertionTrigger]: + return self.trigger + class RowCountChangeVolumeAssertion(BaseEntityAssertion): type: Literal["volume"] - metric: Literal["row_count"] = v1_Field(default="row_count") + metric: Literal["row_count"] = Field(default="row_count") change_type: Literal["absolute", "percentage"] - operator: Operators = v1_Field(discriminator="type", alias="condition") - filters: Optional[DatasetFilter] = v1_Field(default=None) + operator: Operators = Field(discriminator="type", alias="condition") + filters: Optional[DatasetFilter] = Field(default=None) def get_assertion_info( self, @@ -73,26 +86,21 @@ def get_assertion_info( ), ) - -class VolumeAssertion(BaseAssertionProtocol): - __root__: Union[RowCountTotalVolumeAssertion, RowCountChangeVolumeAssertion] - - @property - def assertion(self): - return self.__root__ - def get_id(self) -> str: guid_dict = { - "entity": self.__root__.entity, - "type": self.__root__.type, - "id_raw": self.__root__.id_raw, + "entity": self.entity, + "type": self.type, + "id_raw": self.id_raw, } - return self.__root__.id or datahub_guid(guid_dict) + return self.id or datahub_guid(guid_dict) - def get_assertion_info_aspect( - self, - ) -> AssertionInfo: - return self.__root__.get_assertion_info() + def get_assertion_info_aspect(self) -> AssertionInfo: + return self.get_assertion_info() def get_assertion_trigger(self) -> Optional[AssertionTrigger]: - return self.__root__.trigger + return self.trigger + + +# Pydantic v2 smart union: automatically discriminates based on presence of +# unique fields (eg absence vs presence of change_type) +VolumeAssertion = Union[RowCountTotalVolumeAssertion, RowCountChangeVolumeAssertion] diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion.py index 89ac528efe81a1..d135b63389c217 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion.py @@ -1,7 +1,9 @@ from typing import Optional -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel +from pydantic import BaseModel -class BaseAssertion(v1_ConfigModel): +class BaseAssertion(BaseModel): + model_config = {"extra": "forbid"} + description: Optional[str] = None diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py index 145a6097d7336c..3438f7242de24b 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py @@ -1,8 +1,8 @@ from typing import Optional, Union +from pydantic import BaseModel from typing_extensions import Literal, Protocol -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel from datahub.metadata.schema_classes import ( AssertionStdOperatorClass, AssertionStdParameterClass, @@ -56,7 +56,8 @@ def _generate_assertion_std_parameters( ) -class EqualToOperator(v1_ConfigModel): +class EqualToOperator(BaseModel): + model_config = {"extra": "forbid"} type: Literal["equal_to"] value: Union[str, int, float] @@ -69,7 +70,8 @@ def generate_parameters(self) -> 
AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class BetweenOperator(v1_ConfigModel): +class BetweenOperator(BaseModel): + model_config = {"extra": "forbid"} type: Literal["between"] min: Union[int, float] max: Union[int, float] @@ -85,7 +87,8 @@ def generate_parameters(self) -> AssertionStdParametersClass: ) -class LessThanOperator(v1_ConfigModel): +class LessThanOperator(BaseModel): + model_config = {"extra": "forbid"} type: Literal["less_than"] value: Union[int, float] @@ -98,7 +101,8 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class GreaterThanOperator(v1_ConfigModel): +class GreaterThanOperator(BaseModel): + model_config = {"extra": "forbid"} type: Literal["greater_than"] value: Union[int, float] @@ -111,7 +115,8 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class LessThanOrEqualToOperator(v1_ConfigModel): +class LessThanOrEqualToOperator(BaseModel): + model_config = {"extra": "forbid"} type: Literal["less_than_or_equal_to"] value: Union[int, float] @@ -124,7 +129,8 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class GreaterThanOrEqualToOperator(v1_ConfigModel): +class GreaterThanOrEqualToOperator(BaseModel): + model_config = {"extra": "forbid"} type: Literal["greater_than_or_equal_to"] value: Union[int, float] @@ -137,7 +143,8 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class NotNullOperator(v1_ConfigModel): +class NotNullOperator(BaseModel): + model_config = {"extra": "forbid"} type: Literal["not_null"] operator: str = AssertionStdOperatorClass.NOT_NULL diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/data_quality_assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/data_quality_assertion.py index 975aa359bd2031..2019238aba6ab4 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/data_quality_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/data_quality_assertion.py @@ -1,11 +1,11 @@ from typing import List, Optional, Union +from pydantic import Field, RootModel from typing_extensions import Literal import datahub.emitter.mce_builder as builder from datahub.api.entities.datacontract.assertion import BaseAssertion from datahub.api.entities.datacontract.assertion_operator import Operators -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel, v1_Field from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.metadata.schema_classes import ( AssertionInfoClass, @@ -24,7 +24,7 @@ class IdConfigMixin(BaseAssertion): - id_raw: Optional[str] = v1_Field( + id_raw: Optional[str] = Field( default=None, alias="id", description="The id of the assertion. 
If not provided, one will be generated using the type.", @@ -37,7 +37,7 @@ def generate_default_id(self) -> str: class CustomSQLAssertion(IdConfigMixin, BaseAssertion): type: Literal["custom_sql"] sql: str - operator: Operators = v1_Field(discriminator="type") + operator: Operators = Field(discriminator="type") def generate_default_id(self) -> str: return f"{self.type}-{self.sql}-{self.operator.id()}" @@ -88,20 +88,20 @@ def generate_assertion_info(self, entity_urn: str) -> AssertionInfoClass: ) -class DataQualityAssertion(v1_ConfigModel): - __root__: Union[ +class DataQualityAssertion(RootModel[Union[CustomSQLAssertion, ColumnUniqueAssertion]]): + root: Union[ CustomSQLAssertion, ColumnUniqueAssertion, - ] = v1_Field(discriminator="type") + ] = Field(discriminator="type") @property def id(self) -> str: - if self.__root__.id_raw: - return self.__root__.id_raw + if self.root.id_raw: + return self.root.id_raw try: - return self.__root__.generate_default_id() + return self.root.generate_default_id() except NotImplementedError: - return self.__root__.type + return self.root.type def generate_mcp( self, assertion_urn: str, entity_urn: str @@ -109,6 +109,6 @@ def generate_mcp( return [ MetadataChangeProposalWrapper( entityUrn=assertion_urn, - aspect=self.__root__.generate_assertion_info(entity_urn), + aspect=self.root.generate_assertion_info(entity_urn), ) ] diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py b/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py index 0f8c7ceafe13bb..61a93482b3a1ba 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py @@ -1,6 +1,7 @@ import collections from typing import Dict, Iterable, List, Optional, Tuple, Union +from pydantic import BaseModel, Field, field_validator from ruamel.yaml import YAML from typing_extensions import Literal @@ -10,11 +11,6 @@ ) from datahub.api.entities.datacontract.freshness_assertion import FreshnessAssertion from datahub.api.entities.datacontract.schema_assertion import SchemaAssertion -from datahub.configuration.pydantic_migration_helpers import ( - v1_ConfigModel, - v1_Field, - v1_validator, -) from datahub.emitter.mce_builder import datahub_guid, make_assertion_urn from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.metadata.schema_classes import ( @@ -31,7 +27,7 @@ from datahub.utilities.urns.urn import guess_entity_type -class DataContract(v1_ConfigModel): +class DataContract(BaseModel): """A yml representation of a Data Contract. This model is used as a simpler, Python-native representation of a DataHub data contract. @@ -39,38 +35,37 @@ class DataContract(v1_ConfigModel): that can be emitted to DataHub. """ + model_config = {"extra": "forbid"} + version: Literal[1] - id: Optional[str] = v1_Field( + id: Optional[str] = Field( default=None, alias="urn", description="The data contract urn. 
If not provided, one will be generated.", ) - entity: str = v1_Field( + entity: str = Field( description="The entity urn that the Data Contract is associated with" ) - properties: Optional[Dict[str, Union[str, float, List[Union[str, float]]]]] = ( - v1_Field( - default=None, - description="Structured properties associated with the data contract.", - ) + properties: Optional[Dict[str, Union[str, float, List[Union[str, float]]]]] = Field( + default=None, + description="Structured properties associated with the data contract.", ) - schema_field: Optional[SchemaAssertion] = v1_Field(default=None, alias="schema") + schema_field: Optional[SchemaAssertion] = Field(default=None, alias="schema") - freshness: Optional[FreshnessAssertion] = v1_Field(default=None) + freshness: Optional[FreshnessAssertion] = Field(default=None) - # TODO: Add a validator to ensure that ids are unique - data_quality: Optional[List[DataQualityAssertion]] = v1_Field(default=None) + data_quality: Optional[List[DataQualityAssertion]] = Field(default=None) _original_yaml_dict: Optional[dict] = None - @v1_validator("data_quality") # type: ignore + @field_validator("data_quality") + @classmethod def validate_data_quality( cls, data_quality: Optional[List[DataQualityAssertion]] ) -> Optional[List[DataQualityAssertion]]: if data_quality: - # Raise an error if there are duplicate ids. id_counts = collections.Counter(dq_check.id for dq_check in data_quality) duplicates = [id for id, count in id_counts.items() if count > 1] @@ -243,8 +238,8 @@ def from_yaml( file: str, ) -> "DataContract": with open(file) as fp: - yaml = YAML(typ="rt") # default, if not specfied, is 'rt' (round-trip) + yaml = YAML(typ="rt") orig_dictionary = yaml.load(fp) - parsed_data_contract = DataContract.parse_obj(orig_dictionary) + parsed_data_contract = DataContract.model_validate(orig_dictionary) parsed_data_contract._original_yaml_dict = orig_dictionary return parsed_data_contract diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/freshness_assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/freshness_assertion.py index 86942766889676..961b1c0bce553e 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/freshness_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/freshness_assertion.py @@ -3,10 +3,10 @@ from datetime import timedelta from typing import List, Union +from pydantic import Field, RootModel from typing_extensions import Literal from datahub.api.entities.datacontract.assertion import BaseAssertion -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel, v1_Field from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.metadata.schema_classes import ( AssertionInfoClass, @@ -24,10 +24,10 @@ class CronFreshnessAssertion(BaseAssertion): type: Literal["cron"] - cron: str = v1_Field( + cron: str = Field( description="The cron expression to use. See https://crontab.guru/ for help." ) - timezone: str = v1_Field( + timezone: str = Field( "UTC", description="The timezone to use for the cron schedule. 
Defaults to UTC.", ) @@ -57,14 +57,16 @@ def generate_freshness_assertion_schedule(self) -> FreshnessAssertionScheduleCla ) -class FreshnessAssertion(v1_ConfigModel): - __root__: Union[CronFreshnessAssertion, FixedIntervalFreshnessAssertion] = v1_Field( +class FreshnessAssertion( + RootModel[Union[CronFreshnessAssertion, FixedIntervalFreshnessAssertion]] +): + root: Union[CronFreshnessAssertion, FixedIntervalFreshnessAssertion] = Field( discriminator="type" ) @property def id(self): - return self.__root__.type + return self.root.type def generate_mcp( self, assertion_urn: str, entity_urn: str @@ -74,8 +76,8 @@ def generate_mcp( freshnessAssertion=FreshnessAssertionInfoClass( entity=entity_urn, type=FreshnessAssertionTypeClass.DATASET_CHANGE, - schedule=self.__root__.generate_freshness_assertion_schedule(), + schedule=self.root.generate_freshness_assertion_schedule(), ), - description=self.__root__.description, + description=self.root.description, ) return [MetadataChangeProposalWrapper(entityUrn=assertion_urn, aspect=aspect)] diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py index 39297d1a98d026..878e3dbc07e343 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py @@ -3,10 +3,10 @@ import json from typing import List, Union +from pydantic import Field, RootModel from typing_extensions import Literal from datahub.api.entities.datacontract.assertion import BaseAssertion -from datahub.configuration.pydantic_migration_helpers import v1_ConfigModel, v1_Field from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.extractor.json_schema_util import get_schema_metadata from datahub.metadata.schema_classes import ( @@ -22,12 +22,11 @@ class JsonSchemaContract(BaseAssertion): type: Literal["json-schema"] - json_schema: dict = v1_Field(alias="json-schema") + json_schema: dict = Field(alias="json-schema") _schema_metadata: SchemaMetadataClass - def _init_private_attributes(self) -> None: - super()._init_private_attributes() + def model_post_init(self, __context: object) -> None: self._schema_metadata = get_schema_metadata( platform="urn:li:dataPlatform:datahub", name="", @@ -46,8 +45,7 @@ class Config: _schema_metadata: SchemaMetadataClass - def _init_private_attributes(self) -> None: - super()._init_private_attributes() + def model_post_init(self, __context: object) -> None: self._schema_metadata = SchemaMetadataClass( schemaName="", platform="urn:li:dataPlatform:datahub", @@ -58,14 +56,14 @@ def _init_private_attributes(self) -> None: ) -class SchemaAssertion(v1_ConfigModel): - __root__: Union[JsonSchemaContract, FieldListSchemaContract] = v1_Field( +class SchemaAssertion(RootModel[Union[JsonSchemaContract, FieldListSchemaContract]]): + root: Union[JsonSchemaContract, FieldListSchemaContract] = Field( discriminator="type" ) @property def id(self): - return self.__root__.type + return self.root.type def generate_mcp( self, assertion_urn: str, entity_urn: str @@ -74,9 +72,9 @@ def generate_mcp( type=AssertionTypeClass.DATA_SCHEMA, schemaAssertion=SchemaAssertionInfoClass( entity=entity_urn, - schema=self.__root__._schema_metadata, + schema=self.root._schema_metadata, ), - description=self.__root__.description, + description=self.root.description, ) return [MetadataChangeProposalWrapper(entityUrn=assertion_urn, aspect=aspect)] diff --git 
a/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py b/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py index c3f0b75ad8eae0..5bc70454e49852 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py +++ b/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py @@ -70,9 +70,6 @@ StructuredPropertyUrn, TagUrn, ) -from datahub.pydantic.compat import ( - PYDANTIC_VERSION, -) from datahub.specific.dataset import DatasetPatchBuilder from datahub.utilities.urns.dataset_urn import DatasetUrn @@ -81,22 +78,12 @@ class StrictModel(BaseModel): - """ - Base model with strict validation. - Compatible with both Pydantic v1 and v2. - """ + """Base model with strict validation.""" - if PYDANTIC_VERSION >= 2: - # Pydantic v2 config - model_config = { - "validate_assignment": True, - "extra": "forbid", - } - else: - # Pydantic v1 config - class Config: - validate_assignment = True - extra = "forbid" + model_config = { + "validate_assignment": True, + "extra": "forbid", + } # Define type aliases for the complex types @@ -292,58 +279,24 @@ def _from_datahub_type( return "record" raise ValueError(f"Type {input_type} is not a valid primitive type") - if PYDANTIC_VERSION < 2: - - def dict(self, **kwargs): - """Custom dict method for Pydantic v1 to handle YAML serialization properly.""" - exclude = kwargs.pop("exclude", None) or set() + def model_dump(self, **kwargs): + """Custom model_dump method for Pydantic v2 to handle YAML serialization properly.""" + exclude = kwargs.pop("exclude", None) or set() - # if nativeDataType and type are identical, exclude nativeDataType from the output - if self.nativeDataType == self.type and self.nativeDataType is not None: - exclude.add("nativeDataType") + if self.nativeDataType == self.type and self.nativeDataType is not None: + exclude.add("nativeDataType") - # if the id is the same as the urn's fieldPath, exclude id from the output + if self.urn: + field_urn = SchemaFieldUrn.from_string(self.urn) + if Dataset._simplify_field_path(field_urn.field_path) == self.id: + exclude.add("urn") - if self.urn: - field_urn = SchemaFieldUrn.from_string(self.urn) - if Dataset._simplify_field_path(field_urn.field_path) == self.id: - exclude.add("urn") - - kwargs.pop("exclude_defaults", None) - - self.structured_properties = ( - StructuredPropertiesHelper.simplify_structured_properties_list( - self.structured_properties - ) - ) - - return super().dict(exclude=exclude, exclude_defaults=True, **kwargs) - - else: - # For v2, implement model_dump with similar logic as dict - def model_dump(self, **kwargs): - """Custom model_dump method for Pydantic v2 to handle YAML serialization properly.""" - exclude = kwargs.pop("exclude", None) or set() - - # if nativeDataType and type are identical, exclude nativeDataType from the output - if self.nativeDataType == self.type and self.nativeDataType is not None: - exclude.add("nativeDataType") - - # if the id is the same as the urn's fieldPath, exclude id from the output - if self.urn: - field_urn = SchemaFieldUrn.from_string(self.urn) - if Dataset._simplify_field_path(field_urn.field_path) == self.id: - exclude.add("urn") - - self.structured_properties = ( - StructuredPropertiesHelper.simplify_structured_properties_list( - self.structured_properties - ) + self.structured_properties = ( + StructuredPropertiesHelper.simplify_structured_properties_list( + self.structured_properties ) - if hasattr(super(), "model_dump"): - return super().model_dump( # type: ignore - exclude=exclude, exclude_defaults=True, 
**kwargs - ) + ) + return super().model_dump(exclude=exclude, exclude_defaults=True, **kwargs) class SchemaSpecification(BaseModel): @@ -954,80 +907,33 @@ def from_datahub( downstreams=downstreams if config.include_downstreams else None, ) - if PYDANTIC_VERSION < 2: - - def dict(self, **kwargs): - """Custom dict method for Pydantic v1 to handle YAML serialization properly.""" - exclude = kwargs.pop("exclude", set()) + def model_dump(self, **kwargs): + """Custom model_dump method for Pydantic v2 to handle YAML serialization properly.""" + exclude = kwargs.pop("exclude", None) or set() - # If id and name are identical, exclude name from the output - if self.id == self.name and self.id is not None: - exclude.add("name") + if self.id == self.name and self.id is not None: + exclude.add("name") - # if subtype and subtypes are identical or subtypes is a singleton list, exclude subtypes from the output - if self.subtypes and len(self.subtypes) == 1: - self.subtype = self.subtypes[0] - exclude.add("subtypes") + if self.subtypes and len(self.subtypes) == 1: + self.subtype = self.subtypes[0] + exclude.add("subtypes") - result = super().dict(exclude=exclude, **kwargs) + result = super().model_dump(exclude=exclude, **kwargs) - # Custom handling for schema_metadata/schema - if self.schema_metadata and "schema" in result: - schema_data = result["schema"] + if self.schema_metadata and "schema" in result: + schema_data = result["schema"] - # Handle fields if they exist - if "fields" in schema_data and isinstance(schema_data["fields"], list): - # Process each field using its custom dict method - processed_fields = [] - if self.schema_metadata and self.schema_metadata.fields: - for field in self.schema_metadata.fields: - if field: - # Use dict method for Pydantic v1 - processed_field = field.dict(**kwargs) - processed_fields.append(processed_field) + if "fields" in schema_data and isinstance(schema_data["fields"], list): + processed_fields = [] + if self.schema_metadata and self.schema_metadata.fields: + for field in self.schema_metadata.fields: + if field: + processed_field = field.model_dump(**kwargs) + processed_fields.append(processed_field) - # Replace the fields in the result with the processed ones - schema_data["fields"] = processed_fields + schema_data["fields"] = processed_fields - return result - else: - - def model_dump(self, **kwargs): - """Custom model_dump method for Pydantic v2 to handle YAML serialization properly.""" - exclude = kwargs.pop("exclude", None) or set() - - # If id and name are identical, exclude name from the output - if self.id == self.name and self.id is not None: - exclude.add("name") - - # if subtype and subtypes are identical or subtypes is a singleton list, exclude subtypes from the output - if self.subtypes and len(self.subtypes) == 1: - self.subtype = self.subtypes[0] - exclude.add("subtypes") - - if hasattr(super(), "model_dump"): - result = super().model_dump(exclude=exclude, **kwargs) # type: ignore - else: - result = super().dict(exclude=exclude, **kwargs) - - # Custom handling for schema_metadata/schema - if self.schema_metadata and "schema" in result: - schema_data = result["schema"] - - # Handle fields if they exist - if "fields" in schema_data and isinstance(schema_data["fields"], list): - # Process each field using its custom model_dump method - processed_fields = [] - if self.schema_metadata and self.schema_metadata.fields: - for field in self.schema_metadata.fields: - if field: - processed_field = field.model_dump(**kwargs) - 
processed_fields.append(processed_field) - - # Replace the fields in the result with the processed ones - schema_data["fields"] = processed_fields - - return result + return result def to_yaml( self, @@ -1038,14 +944,7 @@ def to_yaml( Preserves comments and structure of the existing YAML file. Returns True if file was written, False if no changes were detected. """ - # Create new model data - # Create new model data - choose dict() or model_dump() based on Pydantic version - if PYDANTIC_VERSION >= 2: - new_data = self.model_dump( - exclude_none=True, exclude_unset=True, by_alias=True - ) - else: - new_data = self.dict(exclude_none=True, exclude_unset=True, by_alias=True) + new_data = self.model_dump(exclude_none=True, exclude_unset=True, by_alias=True) # Set up ruamel.yaml for preserving comments yaml_handler = YAML(typ="rt") # round-trip mode diff --git a/metadata-ingestion/src/datahub/configuration/_config_enum.py b/metadata-ingestion/src/datahub/configuration/_config_enum.py index 190a006b077d9f..4eb8a2467127d6 100644 --- a/metadata-ingestion/src/datahub/configuration/_config_enum.py +++ b/metadata-ingestion/src/datahub/configuration/_config_enum.py @@ -1,41 +1,19 @@ from enum import Enum -import pydantic -import pydantic.types -import pydantic.validators - -from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2 - class ConfigEnum(Enum): - # Ideally we would use @staticmethod here, but some versions of Python don't support it. - # See https://github.com/python/mypy/issues/7591. def _generate_next_value_( # type: ignore name: str, start, count, last_values ) -> str: - # This makes the enum value match the enum option name. - # From https://stackoverflow.com/a/44785241/5004662. return name - if PYDANTIC_VERSION_2: - # if TYPE_CHECKING: - # from pydantic import GetCoreSchemaHandler - - @classmethod - def __get_pydantic_core_schema__(cls, source_type, handler): # type: ignore - from pydantic_core import core_schema - - return core_schema.no_info_before_validator_function( - cls.validate, handler(source_type) - ) - - else: + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): # type: ignore + from pydantic_core import core_schema - @classmethod - def __get_validators__(cls) -> "pydantic.types.CallableGenerator": - # We convert the text to uppercase before attempting to match it to an enum value. 
- yield cls.validate - yield pydantic.validators.enum_member_validator + return core_schema.no_info_before_validator_function( + cls.validate, handler(source_type) + ) @classmethod def validate(cls, v): # type: ignore[no-untyped-def] diff --git a/metadata-ingestion/src/datahub/configuration/common.py b/metadata-ingestion/src/datahub/configuration/common.py index ad15c327c2cd18..1c3a9127d0a2e6 100644 --- a/metadata-ingestion/src/datahub/configuration/common.py +++ b/metadata-ingestion/src/datahub/configuration/common.py @@ -21,12 +21,11 @@ import pydantic import pydantic_core from cached_property import cached_property -from pydantic import BaseModel, Extra, ValidationError +from pydantic import BaseModel, ValidationError from pydantic.fields import Field from typing_extensions import Protocol, Self from datahub.configuration._config_enum import ConfigEnum as ConfigEnum -from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2 from datahub.utilities.dedup_list import deduplicate_list REDACT_KEYS = { @@ -125,47 +124,28 @@ def _schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None: for key in remove_fields: del schema["properties"][key] - # This is purely to suppress pydantic's warnings, since this class is used everywhere. - if PYDANTIC_VERSION_2: - extra = "forbid" - ignored_types = (cached_property,) - json_schema_extra = _schema_extra - else: - extra = Extra.forbid - underscore_attrs_are_private = True - keep_untouched = ( - cached_property, - ) # needed to allow cached_property to work. See https://github.com/samuelcolvin/pydantic/issues/1241 for more info. - schema_extra = _schema_extra + extra = "forbid" + ignored_types = (cached_property,) + json_schema_extra = _schema_extra @classmethod def parse_obj_allow_extras(cls, obj: Any) -> Self: """Parse an object while allowing extra fields. - 'parse_obj' in Pydantic v1 is equivalent to 'model_validate' in Pydantic v2. - However, 'parse_obj_allow_extras' in v1 is not directly available in v2. - - `model_validate(..., strict=False)` does not work because it still raises errors on extra fields; - strict=False only affects type coercion and validation strictness, not extra field handling. - - This method temporarily modifies the model's configuration to allow extra fields + This method temporarily modifies the model's configuration to allow extra fields. TODO: Do we really need to support this behaviour? Consider removing this method in future. """ - if PYDANTIC_VERSION_2: - try: - with unittest.mock.patch.dict( - cls.model_config, # type: ignore - {"extra": "allow"}, - clear=False, - ): - cls.model_rebuild(force=True) # type: ignore - return cls.model_validate(obj) - finally: + try: + with unittest.mock.patch.dict( + cls.model_config, # type: ignore + {"extra": "allow"}, + clear=False, + ): cls.model_rebuild(force=True) # type: ignore - else: - with unittest.mock.patch.object(cls.Config, "extra", pydantic.Extra.allow): return cls.model_validate(obj) + finally: + cls.model_rebuild(force=True) # type: ignore class PermissiveConfigModel(ConfigModel): @@ -175,21 +155,14 @@ class PermissiveConfigModel(ConfigModel): # It is usually used for argument bags that are passed through to third-party libraries. 
class Config: - if PYDANTIC_VERSION_2: # noqa: SIM108 - extra = "allow" - else: - extra = Extra.allow + extra = "allow" class ConnectionModel(BaseModel): """Represents the config associated with a connection""" class Config: - if PYDANTIC_VERSION_2: - extra = "allow" - else: - extra = Extra.allow - underscore_attrs_are_private = True + extra = "allow" class TransformerSemantics(ConfigEnum): diff --git a/metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py b/metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py deleted file mode 100644 index db0e8e46930b8c..00000000000000 --- a/metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py +++ /dev/null @@ -1,52 +0,0 @@ -import pydantic.version -from packaging.version import Version - -_pydantic_version = Version(pydantic.version.VERSION) - -PYDANTIC_VERSION_2 = _pydantic_version >= Version("2.0") - -# The pydantic.Discriminator type was added in v2.5.0. -# https://docs.pydantic.dev/latest/changelog/#v250-2023-11-13 -PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR = _pydantic_version >= Version("2.5.0") - - -if PYDANTIC_VERSION_2: - from pydantic import BaseModel as GenericModel - from pydantic.v1 import ( # type: ignore - BaseModel as v1_BaseModel, - Extra as v1_Extra, - Field as v1_Field, - root_validator as v1_root_validator, - validator as v1_validator, - ) -else: - from pydantic import ( # type: ignore - BaseModel as v1_BaseModel, - Extra as v1_Extra, - Field as v1_Field, - root_validator as v1_root_validator, - validator as v1_validator, - ) - from pydantic.generics import GenericModel # type: ignore - - -class v1_ConfigModel(v1_BaseModel): - """A simplified variant of our main ConfigModel class. - - This one only uses pydantic v1 features. - """ - - class Config: - extra = v1_Extra.forbid - underscore_attrs_are_private = True - - -__all__ = [ - "PYDANTIC_VERSION_2", - "PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR", - "GenericModel", - "v1_ConfigModel", - "v1_Field", - "v1_root_validator", - "v1_validator", -] diff --git a/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py b/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py index dcda990cbe9936..b6d68c347331ae 100644 --- a/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py +++ b/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py @@ -7,7 +7,6 @@ from pydantic.fields import Field from datahub.configuration.common import ConfigModel -from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2 from datahub.ingestion.glossary.classifier import Classifier from datahub.utilities.str_enum import StrEnum @@ -51,10 +50,7 @@ class ValuesFactorConfig(ConfigModel): class PredictionFactorsAndWeights(ConfigModel): class Config: - if PYDANTIC_VERSION_2: - populate_by_name = True - else: - allow_population_by_field_name = True + populate_by_name = True Name: float = Field(alias="name") Description: float = Field(alias="description") @@ -64,10 +60,7 @@ class Config: class InfoTypeConfig(ConfigModel): class Config: - if PYDANTIC_VERSION_2: - populate_by_name = True - else: - allow_population_by_field_name = True + populate_by_name = True Prediction_Factors_and_Weights: PredictionFactorsAndWeights = Field( description="Factors and their weights to consider when predicting info types", diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py b/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py 
index b1f050b51d25c1..55ca9226053940 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py @@ -79,7 +79,7 @@ def ensure_field_level_settings_are_normalized(self) -> "DataLakeProfilerConfig" # Disable all field-level metrics. if self.profile_table_level_only: - for field_name in self.__fields__: + for field_name in self.model_fields: if field_name.startswith("include_field_"): setattr(self, field_name, False) diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py index 4e7c09f5a94059..9eb9e073958590 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py @@ -1,652 +1,652 @@ -import logging -from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional - -from datahub.emitter.mce_builder import ( - make_data_platform_urn, - make_dataset_urn_with_platform_instance, -) -from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.ingestion.api.common import PipelineContext -from datahub.ingestion.api.decorators import ( - SupportStatus, - capability, - config_class, - platform_name, - support_status, -) -from datahub.ingestion.api.source import ( - MetadataWorkUnitProcessor, - SourceCapability, - SourceReport, -) -from datahub.ingestion.api.workunit import MetadataWorkUnit -from datahub.ingestion.source.common.subtypes import SourceCapabilityModifier -from datahub.ingestion.source.dremio.dremio_api import ( - DremioAPIOperations, - DremioEdition, -) -from datahub.ingestion.source.dremio.dremio_aspects import DremioAspects -from datahub.ingestion.source.dremio.dremio_config import ( - DremioSourceConfig, - DremioSourceMapping, -) -from datahub.ingestion.source.dremio.dremio_datahub_source_mapping import ( - DremioToDataHubSourceTypeMapping, -) -from datahub.ingestion.source.dremio.dremio_entities import ( - DremioCatalog, - DremioContainer, - DremioDataset, - DremioDatasetType, - DremioGlossaryTerm, - DremioQuery, - DremioSourceContainer, -) -from datahub.ingestion.source.dremio.dremio_profiling import DremioProfiler -from datahub.ingestion.source.dremio.dremio_reporting import DremioSourceReport -from datahub.ingestion.source.state.stale_entity_removal_handler import ( - StaleEntityRemovalHandler, -) -from datahub.ingestion.source.state.stateful_ingestion_base import ( - StatefulIngestionSourceBase, -) -from datahub.ingestion.source_report.ingestion_stage import ( - LINEAGE_EXTRACTION, - METADATA_EXTRACTION, - PROFILING, -) -from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( - DatasetLineageTypeClass, - UpstreamClass, - UpstreamLineage, -) -from datahub.metadata.schema_classes import SchemaMetadataClass -from datahub.metadata.urns import CorpUserUrn -from datahub.sql_parsing.sql_parsing_aggregator import ( - KnownQueryLineageInfo, - ObservedQuery, - SqlParsingAggregator, -) - -logger = logging.getLogger(__name__) - - -@dataclass -class DremioSourceMapEntry: - platform: str - source_name: str - dremio_source_category: str - root_path: str = "" - database_name: str = "" - platform_instance: Optional[str] = None - env: Optional[str] = None - - -@platform_name("Dremio") -@config_class(DremioSourceConfig) -@support_status(SupportStatus.CERTIFIED) 
-@capability( - SourceCapability.CONTAINERS, - "Enabled by default", - subtype_modifier=[ - SourceCapabilityModifier.DREMIO_SPACE, - SourceCapabilityModifier.DREMIO_SOURCE, - ], -) -@capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") -@capability(SourceCapability.DESCRIPTIONS, "Enabled by default") -@capability(SourceCapability.DOMAINS, "Supported via the `domain` config field") -@capability( - SourceCapability.LINEAGE_COARSE, - "Enabled by default", - subtype_modifier=[ - SourceCapabilityModifier.TABLE, - ], -) -@capability( - SourceCapability.LINEAGE_FINE, - "Extract column-level lineage", - subtype_modifier=[ - SourceCapabilityModifier.TABLE, - ], -) -@capability(SourceCapability.OWNERSHIP, "Enabled by default") -@capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default") -@capability(SourceCapability.USAGE_STATS, "Enabled by default to get usage stats") -class DremioSource(StatefulIngestionSourceBase): - """ - This plugin integrates with Dremio to extract and ingest metadata into DataHub. - The following types of metadata are extracted: - - - Metadata for Spaces, Folders, Sources, and Datasets: - - Includes physical and virtual datasets, with detailed information about each dataset. - - Extracts metadata about Dremio's organizational hierarchy: Spaces (top-level), Folders (sub-level), and Sources (external data connections). - - - Schema and Column Information: - - Column types and schema metadata associated with each physical and virtual dataset. - - Extracts column-level metadata, such as names, data types, and descriptions, if available. - - - Lineage Information: - - Dataset-level and column-level lineage tracking: - - Dataset-level lineage shows dependencies and relationships between physical and virtual datasets. - - Column-level lineage tracks transformations applied to individual columns across datasets. - - Lineage information helps trace the flow of data and transformations within Dremio. - - - Ownership and Glossary Terms: - - Metadata related to ownership of datasets, extracted from Dremio’s ownership model. - - Glossary terms and business metadata associated with datasets, providing additional context to the data. - - Note: Ownership information will only be available for the Cloud and Enterprise editions, it will not be available for the Community edition. - - - Optional SQL Profiling (if enabled): - - Table, row, and column statistics can be profiled and ingested via optional SQL queries. - - Extracts statistics about tables and columns, such as row counts and data distribution, for better insight into the dataset structure. 
- """ - - config: DremioSourceConfig - report: DremioSourceReport - - def __init__(self, config: DremioSourceConfig, ctx: PipelineContext): - super().__init__(config, ctx) - self.default_db = "dremio" - self.config = config - self.report = DremioSourceReport() - - # Set time window for query lineage extraction - self.report.window_start_time, self.report.window_end_time = ( - self.config.start_time, - self.config.end_time, - ) - - self.source_map: Dict[str, DremioSourceMapEntry] = dict() - - # Initialize API operations - dremio_api = DremioAPIOperations(self.config, self.report) - - # Initialize catalog - self.dremio_catalog = DremioCatalog(dremio_api) - - # Initialize aspects - self.dremio_aspects = DremioAspects( - platform=self.get_platform(), - domain=self.config.domain, - ingest_owner=self.config.ingest_owner, - platform_instance=self.config.platform_instance, - env=self.config.env, - ui_url=dremio_api.ui_url, - ) - self.max_workers = config.max_workers - - self.sql_parsing_aggregator = SqlParsingAggregator( - platform=make_data_platform_urn(self.get_platform()), - platform_instance=self.config.platform_instance, - env=self.config.env, - graph=self.ctx.graph, - generate_usage_statistics=True, - generate_operations=True, - usage_config=self.config.usage, - ) - self.report.sql_aggregator = self.sql_parsing_aggregator.report - - # For profiling - self.profiler = DremioProfiler(config, self.report, dremio_api) - - @classmethod - def create(cls, config_dict: Dict, ctx: PipelineContext) -> "DremioSource": - config = DremioSourceConfig.parse_obj(config_dict) - return cls(config, ctx) - - def get_platform(self) -> str: - return "dremio" - - def _build_source_map(self) -> Dict[str, DremioSourceMapEntry]: - dremio_sources = list(self.dremio_catalog.get_sources()) - source_mappings_config = self.config.source_mappings or [] - - source_map = build_dremio_source_map(dremio_sources, source_mappings_config) - logger.info(f"Full source map: {source_map}") - - return source_map - - def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: - return [ - *super().get_workunit_processors(), - StaleEntityRemovalHandler.create( - self, self.config, self.ctx - ).workunit_processor, - ] - - def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: - """ - Internal method to generate workunits for Dremio metadata. 
- """ - - self.source_map = self._build_source_map() - - with self.report.new_stage(METADATA_EXTRACTION): - # Process Containers - containers = self.dremio_catalog.get_containers() - for container in containers: - try: - yield from self.process_container(container) - logger.info( - f"Dremio container {container.container_name} emitted successfully" - ) - except Exception as exc: - self.report.num_containers_failed += 1 - self.report.report_failure( - message="Failed to process Dremio container", - context=f"{'.'.join(container.path)}.{container.container_name}", - exc=exc, - ) - - # Process Datasets - for dataset_info in self.dremio_catalog.get_datasets(): - try: - yield from self.process_dataset(dataset_info) - logger.info( - f"Dremio dataset {'.'.join(dataset_info.path)}.{dataset_info.resource_name} emitted successfully" - ) - except Exception as exc: - self.report.num_datasets_failed += 1 # Increment failed datasets - self.report.report_failure( - message="Failed to process Dremio dataset", - context=f"{'.'.join(dataset_info.path)}.{dataset_info.resource_name}", - exc=exc, - ) - - # Process Glossary Terms using streaming - for glossary_term in self.dremio_catalog.get_glossary_terms(): - try: - yield from self.process_glossary_term(glossary_term) - except Exception as exc: - self.report.report_failure( - message="Failed to process Glossary terms", - context=f"{glossary_term.glossary_term}", - exc=exc, - ) - - # Optionally Process Query Lineage - if self.config.include_query_lineage: - with self.report.new_stage(LINEAGE_EXTRACTION): - self.get_query_lineage_workunits() - - # Generate workunit for aggregated SQL parsing results - for mcp in self.sql_parsing_aggregator.gen_metadata(): - yield mcp.as_workunit() - - # Profiling - if self.config.is_profiling_enabled(): - with ( - self.report.new_stage(PROFILING), - ThreadPoolExecutor( - max_workers=self.config.profiling.max_workers - ) as executor, - ): - # Collect datasets for profiling - datasets_for_profiling = list(self.dremio_catalog.get_datasets()) - future_to_dataset = { - executor.submit(self.generate_profiles, dataset): dataset - for dataset in datasets_for_profiling - } - - for future in as_completed(future_to_dataset): - dataset_info = future_to_dataset[future] - try: - yield from future.result() - except Exception as exc: - self.report.profiling_skipped_other[ - dataset_info.resource_name - ] += 1 - self.report.report_failure( - message="Failed to profile dataset", - context=f"{'.'.join(dataset_info.path)}.{dataset_info.resource_name}", - exc=exc, - ) - - def process_container( - self, container_info: DremioContainer - ) -> Iterable[MetadataWorkUnit]: - """ - Process a Dremio container and generate metadata workunits. - """ - container_urn = self.dremio_aspects.get_container_urn( - path=container_info.path, name=container_info.container_name - ) - - yield from self.dremio_aspects.populate_container_mcp( - container_urn, container_info - ) - - def process_dataset( - self, dataset_info: DremioDataset - ) -> Iterable[MetadataWorkUnit]: - """ - Process a Dremio dataset and generate metadata workunits. 
- """ - - schema_str = ".".join(dataset_info.path) - - dataset_name = f"{schema_str}.{dataset_info.resource_name}".lower() - - self.report.report_entity_scanned(dataset_name, dataset_info.dataset_type.value) - if not self.config.dataset_pattern.allowed(dataset_name): - self.report.report_dropped(dataset_name) - return - - dataset_urn = make_dataset_urn_with_platform_instance( - platform=make_data_platform_urn(self.get_platform()), - name=f"dremio.{dataset_name}", - env=self.config.env, - platform_instance=self.config.platform_instance, - ) - - for dremio_mcp in self.dremio_aspects.populate_dataset_mcp( - dataset_urn, dataset_info - ): - yield dremio_mcp - # Check if the emitted aspect is SchemaMetadataClass - if isinstance( - dremio_mcp.metadata, MetadataChangeProposalWrapper - ) and isinstance(dremio_mcp.metadata.aspect, SchemaMetadataClass): - self.sql_parsing_aggregator.register_schema( - urn=dataset_urn, - schema=dremio_mcp.metadata.aspect, - ) - - if dataset_info.dataset_type == DremioDatasetType.VIEW: - if ( - self.dremio_catalog.edition == DremioEdition.ENTERPRISE - and dataset_info.parents - ): - yield from self.generate_view_lineage( - parents=dataset_info.parents, - dataset_urn=dataset_urn, - ) - - if dataset_info.sql_definition: - self.sql_parsing_aggregator.add_view_definition( - view_urn=dataset_urn, - view_definition=dataset_info.sql_definition, - default_db=self.default_db, - default_schema=dataset_info.default_schema, - ) - - elif dataset_info.dataset_type == DremioDatasetType.TABLE: - dremio_source = dataset_info.path[0] if dataset_info.path else None - - if dremio_source: - upstream_urn = self._map_dremio_dataset_to_urn( - dremio_source=dremio_source, - dremio_path=dataset_info.path, - dremio_dataset=dataset_info.resource_name, - ) - logger.debug(f"Upstream dataset for {dataset_urn}: {upstream_urn}") - - if upstream_urn: - upstream_lineage = UpstreamLineage( - upstreams=[ - UpstreamClass( - dataset=upstream_urn, - type=DatasetLineageTypeClass.COPY, - ) - ] - ) - mcp = MetadataChangeProposalWrapper( - entityUrn=dataset_urn, - aspect=upstream_lineage, - ) - yield mcp.as_workunit() - self.sql_parsing_aggregator.add_known_lineage_mapping( - upstream_urn=upstream_urn, - downstream_urn=dataset_urn, - lineage_type=DatasetLineageTypeClass.COPY, - ) - - def process_glossary_term( - self, glossary_term_info: DremioGlossaryTerm - ) -> Iterable[MetadataWorkUnit]: - """ - Process a Dremio container and generate metadata workunits. - """ - - yield from self.dremio_aspects.populate_glossary_term_mcp(glossary_term_info) - - def generate_profiles( - self, dataset_info: DremioDataset - ) -> Iterable[MetadataWorkUnit]: - schema_str = ".".join(dataset_info.path) - dataset_name = f"{schema_str}.{dataset_info.resource_name}".lower() - dataset_urn = make_dataset_urn_with_platform_instance( - platform=make_data_platform_urn(self.get_platform()), - name=f"dremio.{dataset_name}", - env=self.config.env, - platform_instance=self.config.platform_instance, - ) - yield from self.profiler.get_workunits(dataset_info, dataset_urn) - - def generate_view_lineage( - self, dataset_urn: str, parents: List[str] - ) -> Iterable[MetadataWorkUnit]: - """ - Generate lineage information for views. 
- """ - upstream_urns = [ - make_dataset_urn_with_platform_instance( - platform=make_data_platform_urn(self.get_platform()), - name=f"dremio.{upstream_table.lower()}", - env=self.config.env, - platform_instance=self.config.platform_instance, - ) - for upstream_table in parents - ] - - lineage = UpstreamLineage( - upstreams=[ - UpstreamClass( - dataset=upstream_urn, - type=DatasetLineageTypeClass.VIEW, - ) - for upstream_urn in upstream_urns - ] - ) - mcp = MetadataChangeProposalWrapper( - entityUrn=dataset_urn, - aspect=lineage, - ) - - for upstream_urn in upstream_urns: - self.sql_parsing_aggregator.add_known_lineage_mapping( - upstream_urn=upstream_urn, - downstream_urn=dataset_urn, - lineage_type=DatasetLineageTypeClass.VIEW, - ) - - yield MetadataWorkUnit(id=f"{dataset_urn}-upstreamLineage", mcp=mcp) - - def get_query_lineage_workunits(self) -> None: - """ - Process query lineage information. - """ - - queries = self.dremio_catalog.get_queries() - - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - future_to_query = { - executor.submit(self.process_query, query): query for query in queries - } - - for future in as_completed(future_to_query): - query = future_to_query[future] - try: - future.result() - except Exception as exc: - self.report.report_failure( - message="Failed to process dremio query", - context=f"{query.job_id}: {exc}", - exc=exc, - ) - - def process_query(self, query: DremioQuery) -> None: - """ - Process a single Dremio query for lineage information. - """ - - if query.query and query.affected_dataset: - upstream_urns = [ - make_dataset_urn_with_platform_instance( - platform=make_data_platform_urn(self.get_platform()), - name=f"dremio.{ds.lower()}", - env=self.config.env, - platform_instance=self.config.platform_instance, - ) - for ds in query.queried_datasets - ] - - downstream_urn = make_dataset_urn_with_platform_instance( - platform=make_data_platform_urn(self.get_platform()), - name=f"dremio.{query.affected_dataset.lower()}", - env=self.config.env, - platform_instance=self.config.platform_instance, - ) - - # Add query to SqlParsingAggregator - self.sql_parsing_aggregator.add_known_query_lineage( - KnownQueryLineageInfo( - query_text=query.query, - upstreams=upstream_urns, - downstream=downstream_urn, - ), - merge_lineage=True, - ) - - # Add observed query - self.sql_parsing_aggregator.add_observed_query( - ObservedQuery( - query=query.query, - timestamp=query.submitted_ts, - user=CorpUserUrn(username=query.username), - default_db=self.default_db, - ) - ) - - def _map_dremio_dataset_to_urn( - self, - dremio_source: str, - dremio_path: List[str], - dremio_dataset: str, - ) -> Optional[str]: - """ - Map a Dremio dataset to a DataHub URN. - """ - mapping = self.source_map.get(dremio_source.lower()) - if not mapping: - return None - - platform = mapping.platform - if not platform: - return None - - platform_instance = mapping.platform_instance - env = mapping.env or self.config.env - - root_path = "" - database_name = "" - - if mapping.dremio_source_category == "file_object_storage": - if mapping.root_path: - root_path = f"{mapping.root_path[1:]}/" - dremio_dataset = f"{root_path}{'/'.join(dremio_path[1:])}/{dremio_dataset}" - else: - if mapping.database_name: - database_name = f"{mapping.database_name}." 
- dremio_dataset = ( - f"{database_name}{'.'.join(dremio_path[1:])}.{dremio_dataset}" - ) - - if platform_instance: - return make_dataset_urn_with_platform_instance( - platform=platform.lower(), - name=dremio_dataset, - platform_instance=platform_instance, - env=env, - ) - - return make_dataset_urn_with_platform_instance( - platform=platform.lower(), - name=dremio_dataset, - platform_instance=None, - env=env, - ) - - def get_report(self) -> SourceReport: - """ - Get the source report. - """ - return self.report - - -def build_dremio_source_map( - dremio_sources: Iterable[DremioSourceContainer], - source_mappings_config: List[DremioSourceMapping], -) -> Dict[str, DremioSourceMapEntry]: - """ - Builds a source mapping dictionary to support external lineage generation across - multiple Dremio sources, based on provided configuration mappings. - - This method operates as follows: - - Returns: - Dict[str, Dict]: A dictionary (`source_map`) where each key is a source name - (lowercased) and each value is another entry containing: - - `platform`: The source platform. - - `source_name`: The source name. - - `dremio_source_category`: The type mapped to DataHub, - e.g., "database", "folder". - - Optional `root_path`, `database_name`, `platform_instance`, - and `env` if provided in the configuration. - Example: - This method is used internally within the class to generate mappings before - creating cross-platform lineage. - - """ - source_map = {} - for source in dremio_sources: - current_source_name = source.container_name - - source_type = source.dremio_source_type.lower() - source_category = DremioToDataHubSourceTypeMapping.get_category(source_type) - datahub_platform = DremioToDataHubSourceTypeMapping.get_datahub_platform( - source_type - ) - root_path = source.root_path.lower() if source.root_path else "" - database_name = source.database_name.lower() if source.database_name else "" - source_present = False - - for mapping in source_mappings_config: - if mapping.source_name.lower() == current_source_name.lower(): - source_map[current_source_name.lower()] = DremioSourceMapEntry( - platform=mapping.platform, - source_name=mapping.source_name, - dremio_source_category=source_category, - root_path=root_path, - database_name=database_name, - platform_instance=mapping.platform_instance, - env=mapping.env, - ) - source_present = True - break - - if not source_present: - source_map[current_source_name.lower()] = DremioSourceMapEntry( - platform=datahub_platform, - source_name=current_source_name, - dremio_source_category=source_category, - root_path=root_path, - database_name=database_name, - platform_instance=None, - env=None, - ) - - return source_map +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional + +from datahub.emitter.mce_builder import ( + make_data_platform_urn, + make_dataset_urn_with_platform_instance, +) +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.api.decorators import ( + SupportStatus, + capability, + config_class, + platform_name, + support_status, +) +from datahub.ingestion.api.source import ( + MetadataWorkUnitProcessor, + SourceCapability, + SourceReport, +) +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.source.common.subtypes import SourceCapabilityModifier +from datahub.ingestion.source.dremio.dremio_api import ( + 
DremioAPIOperations, + DremioEdition, +) +from datahub.ingestion.source.dremio.dremio_aspects import DremioAspects +from datahub.ingestion.source.dremio.dremio_config import ( + DremioSourceConfig, + DremioSourceMapping, +) +from datahub.ingestion.source.dremio.dremio_datahub_source_mapping import ( + DremioToDataHubSourceTypeMapping, +) +from datahub.ingestion.source.dremio.dremio_entities import ( + DremioCatalog, + DremioContainer, + DremioDataset, + DremioDatasetType, + DremioGlossaryTerm, + DremioQuery, + DremioSourceContainer, +) +from datahub.ingestion.source.dremio.dremio_profiling import DremioProfiler +from datahub.ingestion.source.dremio.dremio_reporting import DremioSourceReport +from datahub.ingestion.source.state.stale_entity_removal_handler import ( + StaleEntityRemovalHandler, +) +from datahub.ingestion.source.state.stateful_ingestion_base import ( + StatefulIngestionSourceBase, +) +from datahub.ingestion.source_report.ingestion_stage import ( + LINEAGE_EXTRACTION, + METADATA_EXTRACTION, + PROFILING, +) +from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( + DatasetLineageTypeClass, + UpstreamClass, + UpstreamLineage, +) +from datahub.metadata.schema_classes import SchemaMetadataClass +from datahub.metadata.urns import CorpUserUrn +from datahub.sql_parsing.sql_parsing_aggregator import ( + KnownQueryLineageInfo, + ObservedQuery, + SqlParsingAggregator, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class DremioSourceMapEntry: + platform: str + source_name: str + dremio_source_category: str + root_path: str = "" + database_name: str = "" + platform_instance: Optional[str] = None + env: Optional[str] = None + + +@platform_name("Dremio") +@config_class(DremioSourceConfig) +@support_status(SupportStatus.CERTIFIED) +@capability( + SourceCapability.CONTAINERS, + "Enabled by default", + subtype_modifier=[ + SourceCapabilityModifier.DREMIO_SPACE, + SourceCapabilityModifier.DREMIO_SOURCE, + ], +) +@capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") +@capability(SourceCapability.DESCRIPTIONS, "Enabled by default") +@capability(SourceCapability.DOMAINS, "Supported via the `domain` config field") +@capability( + SourceCapability.LINEAGE_COARSE, + "Enabled by default", + subtype_modifier=[ + SourceCapabilityModifier.TABLE, + ], +) +@capability( + SourceCapability.LINEAGE_FINE, + "Extract column-level lineage", + subtype_modifier=[ + SourceCapabilityModifier.TABLE, + ], +) +@capability(SourceCapability.OWNERSHIP, "Enabled by default") +@capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default") +@capability(SourceCapability.USAGE_STATS, "Enabled by default to get usage stats") +class DremioSource(StatefulIngestionSourceBase): + """ + This plugin integrates with Dremio to extract and ingest metadata into DataHub. + The following types of metadata are extracted: + + - Metadata for Spaces, Folders, Sources, and Datasets: + - Includes physical and virtual datasets, with detailed information about each dataset. + - Extracts metadata about Dremio's organizational hierarchy: Spaces (top-level), Folders (sub-level), and Sources (external data connections). + + - Schema and Column Information: + - Column types and schema metadata associated with each physical and virtual dataset. + - Extracts column-level metadata, such as names, data types, and descriptions, if available. 
+ + - Lineage Information: + - Dataset-level and column-level lineage tracking: + - Dataset-level lineage shows dependencies and relationships between physical and virtual datasets. + - Column-level lineage tracks transformations applied to individual columns across datasets. + - Lineage information helps trace the flow of data and transformations within Dremio. + + - Ownership and Glossary Terms: + - Metadata related to ownership of datasets, extracted from Dremio’s ownership model. + - Glossary terms and business metadata associated with datasets, providing additional context to the data. + - Note: Ownership information will only be available for the Cloud and Enterprise editions, it will not be available for the Community edition. + + - Optional SQL Profiling (if enabled): + - Table, row, and column statistics can be profiled and ingested via optional SQL queries. + - Extracts statistics about tables and columns, such as row counts and data distribution, for better insight into the dataset structure. + """ + + config: DremioSourceConfig + report: DremioSourceReport + + def __init__(self, config: DremioSourceConfig, ctx: PipelineContext): + super().__init__(config, ctx) + self.default_db = "dremio" + self.config = config + self.report = DremioSourceReport() + + # Set time window for query lineage extraction + self.report.window_start_time, self.report.window_end_time = ( + self.config.start_time, + self.config.end_time, + ) + + self.source_map: Dict[str, DremioSourceMapEntry] = dict() + + # Initialize API operations + dremio_api = DremioAPIOperations(self.config, self.report) + + # Initialize catalog + self.dremio_catalog = DremioCatalog(dremio_api) + + # Initialize aspects + self.dremio_aspects = DremioAspects( + platform=self.get_platform(), + domain=self.config.domain, + ingest_owner=self.config.ingest_owner, + platform_instance=self.config.platform_instance, + env=self.config.env, + ui_url=dremio_api.ui_url, + ) + self.max_workers = config.max_workers + + self.sql_parsing_aggregator = SqlParsingAggregator( + platform=make_data_platform_urn(self.get_platform()), + platform_instance=self.config.platform_instance, + env=self.config.env, + graph=self.ctx.graph, + generate_usage_statistics=True, + generate_operations=True, + usage_config=self.config.usage, + ) + self.report.sql_aggregator = self.sql_parsing_aggregator.report + + # For profiling + self.profiler = DremioProfiler(config, self.report, dremio_api) + + @classmethod + def create(cls, config_dict: Dict, ctx: PipelineContext) -> "DremioSource": + config = DremioSourceConfig.model_validate(config_dict) + return cls(config, ctx) + + def get_platform(self) -> str: + return "dremio" + + def _build_source_map(self) -> Dict[str, DremioSourceMapEntry]: + dremio_sources = list(self.dremio_catalog.get_sources()) + source_mappings_config = self.config.source_mappings or [] + + source_map = build_dremio_source_map(dremio_sources, source_mappings_config) + logger.info(f"Full source map: {source_map}") + + return source_map + + def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: + return [ + *super().get_workunit_processors(), + StaleEntityRemovalHandler.create( + self, self.config, self.ctx + ).workunit_processor, + ] + + def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: + """ + Internal method to generate workunits for Dremio metadata. 
+ """ + + self.source_map = self._build_source_map() + + with self.report.new_stage(METADATA_EXTRACTION): + # Process Containers + containers = self.dremio_catalog.get_containers() + for container in containers: + try: + yield from self.process_container(container) + logger.info( + f"Dremio container {container.container_name} emitted successfully" + ) + except Exception as exc: + self.report.num_containers_failed += 1 + self.report.report_failure( + message="Failed to process Dremio container", + context=f"{'.'.join(container.path)}.{container.container_name}", + exc=exc, + ) + + # Process Datasets + for dataset_info in self.dremio_catalog.get_datasets(): + try: + yield from self.process_dataset(dataset_info) + logger.info( + f"Dremio dataset {'.'.join(dataset_info.path)}.{dataset_info.resource_name} emitted successfully" + ) + except Exception as exc: + self.report.num_datasets_failed += 1 # Increment failed datasets + self.report.report_failure( + message="Failed to process Dremio dataset", + context=f"{'.'.join(dataset_info.path)}.{dataset_info.resource_name}", + exc=exc, + ) + + # Process Glossary Terms using streaming + for glossary_term in self.dremio_catalog.get_glossary_terms(): + try: + yield from self.process_glossary_term(glossary_term) + except Exception as exc: + self.report.report_failure( + message="Failed to process Glossary terms", + context=f"{glossary_term.glossary_term}", + exc=exc, + ) + + # Optionally Process Query Lineage + if self.config.include_query_lineage: + with self.report.new_stage(LINEAGE_EXTRACTION): + self.get_query_lineage_workunits() + + # Generate workunit for aggregated SQL parsing results + for mcp in self.sql_parsing_aggregator.gen_metadata(): + yield mcp.as_workunit() + + # Profiling + if self.config.is_profiling_enabled(): + with ( + self.report.new_stage(PROFILING), + ThreadPoolExecutor( + max_workers=self.config.profiling.max_workers + ) as executor, + ): + # Collect datasets for profiling + datasets_for_profiling = list(self.dremio_catalog.get_datasets()) + future_to_dataset = { + executor.submit(self.generate_profiles, dataset): dataset + for dataset in datasets_for_profiling + } + + for future in as_completed(future_to_dataset): + dataset_info = future_to_dataset[future] + try: + yield from future.result() + except Exception as exc: + self.report.profiling_skipped_other[ + dataset_info.resource_name + ] += 1 + self.report.report_failure( + message="Failed to profile dataset", + context=f"{'.'.join(dataset_info.path)}.{dataset_info.resource_name}", + exc=exc, + ) + + def process_container( + self, container_info: DremioContainer + ) -> Iterable[MetadataWorkUnit]: + """ + Process a Dremio container and generate metadata workunits. + """ + container_urn = self.dremio_aspects.get_container_urn( + path=container_info.path, name=container_info.container_name + ) + + yield from self.dremio_aspects.populate_container_mcp( + container_urn, container_info + ) + + def process_dataset( + self, dataset_info: DremioDataset + ) -> Iterable[MetadataWorkUnit]: + """ + Process a Dremio dataset and generate metadata workunits. 
+ """ + + schema_str = ".".join(dataset_info.path) + + dataset_name = f"{schema_str}.{dataset_info.resource_name}".lower() + + self.report.report_entity_scanned(dataset_name, dataset_info.dataset_type.value) + if not self.config.dataset_pattern.allowed(dataset_name): + self.report.report_dropped(dataset_name) + return + + dataset_urn = make_dataset_urn_with_platform_instance( + platform=make_data_platform_urn(self.get_platform()), + name=f"dremio.{dataset_name}", + env=self.config.env, + platform_instance=self.config.platform_instance, + ) + + for dremio_mcp in self.dremio_aspects.populate_dataset_mcp( + dataset_urn, dataset_info + ): + yield dremio_mcp + # Check if the emitted aspect is SchemaMetadataClass + if isinstance( + dremio_mcp.metadata, MetadataChangeProposalWrapper + ) and isinstance(dremio_mcp.metadata.aspect, SchemaMetadataClass): + self.sql_parsing_aggregator.register_schema( + urn=dataset_urn, + schema=dremio_mcp.metadata.aspect, + ) + + if dataset_info.dataset_type == DremioDatasetType.VIEW: + if ( + self.dremio_catalog.edition == DremioEdition.ENTERPRISE + and dataset_info.parents + ): + yield from self.generate_view_lineage( + parents=dataset_info.parents, + dataset_urn=dataset_urn, + ) + + if dataset_info.sql_definition: + self.sql_parsing_aggregator.add_view_definition( + view_urn=dataset_urn, + view_definition=dataset_info.sql_definition, + default_db=self.default_db, + default_schema=dataset_info.default_schema, + ) + + elif dataset_info.dataset_type == DremioDatasetType.TABLE: + dremio_source = dataset_info.path[0] if dataset_info.path else None + + if dremio_source: + upstream_urn = self._map_dremio_dataset_to_urn( + dremio_source=dremio_source, + dremio_path=dataset_info.path, + dremio_dataset=dataset_info.resource_name, + ) + logger.debug(f"Upstream dataset for {dataset_urn}: {upstream_urn}") + + if upstream_urn: + upstream_lineage = UpstreamLineage( + upstreams=[ + UpstreamClass( + dataset=upstream_urn, + type=DatasetLineageTypeClass.COPY, + ) + ] + ) + mcp = MetadataChangeProposalWrapper( + entityUrn=dataset_urn, + aspect=upstream_lineage, + ) + yield mcp.as_workunit() + self.sql_parsing_aggregator.add_known_lineage_mapping( + upstream_urn=upstream_urn, + downstream_urn=dataset_urn, + lineage_type=DatasetLineageTypeClass.COPY, + ) + + def process_glossary_term( + self, glossary_term_info: DremioGlossaryTerm + ) -> Iterable[MetadataWorkUnit]: + """ + Process a Dremio container and generate metadata workunits. + """ + + yield from self.dremio_aspects.populate_glossary_term_mcp(glossary_term_info) + + def generate_profiles( + self, dataset_info: DremioDataset + ) -> Iterable[MetadataWorkUnit]: + schema_str = ".".join(dataset_info.path) + dataset_name = f"{schema_str}.{dataset_info.resource_name}".lower() + dataset_urn = make_dataset_urn_with_platform_instance( + platform=make_data_platform_urn(self.get_platform()), + name=f"dremio.{dataset_name}", + env=self.config.env, + platform_instance=self.config.platform_instance, + ) + yield from self.profiler.get_workunits(dataset_info, dataset_urn) + + def generate_view_lineage( + self, dataset_urn: str, parents: List[str] + ) -> Iterable[MetadataWorkUnit]: + """ + Generate lineage information for views. 
+ """ + upstream_urns = [ + make_dataset_urn_with_platform_instance( + platform=make_data_platform_urn(self.get_platform()), + name=f"dremio.{upstream_table.lower()}", + env=self.config.env, + platform_instance=self.config.platform_instance, + ) + for upstream_table in parents + ] + + lineage = UpstreamLineage( + upstreams=[ + UpstreamClass( + dataset=upstream_urn, + type=DatasetLineageTypeClass.VIEW, + ) + for upstream_urn in upstream_urns + ] + ) + mcp = MetadataChangeProposalWrapper( + entityUrn=dataset_urn, + aspect=lineage, + ) + + for upstream_urn in upstream_urns: + self.sql_parsing_aggregator.add_known_lineage_mapping( + upstream_urn=upstream_urn, + downstream_urn=dataset_urn, + lineage_type=DatasetLineageTypeClass.VIEW, + ) + + yield MetadataWorkUnit(id=f"{dataset_urn}-upstreamLineage", mcp=mcp) + + def get_query_lineage_workunits(self) -> None: + """ + Process query lineage information. + """ + + queries = self.dremio_catalog.get_queries() + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_to_query = { + executor.submit(self.process_query, query): query for query in queries + } + + for future in as_completed(future_to_query): + query = future_to_query[future] + try: + future.result() + except Exception as exc: + self.report.report_failure( + message="Failed to process dremio query", + context=f"{query.job_id}: {exc}", + exc=exc, + ) + + def process_query(self, query: DremioQuery) -> None: + """ + Process a single Dremio query for lineage information. + """ + + if query.query and query.affected_dataset: + upstream_urns = [ + make_dataset_urn_with_platform_instance( + platform=make_data_platform_urn(self.get_platform()), + name=f"dremio.{ds.lower()}", + env=self.config.env, + platform_instance=self.config.platform_instance, + ) + for ds in query.queried_datasets + ] + + downstream_urn = make_dataset_urn_with_platform_instance( + platform=make_data_platform_urn(self.get_platform()), + name=f"dremio.{query.affected_dataset.lower()}", + env=self.config.env, + platform_instance=self.config.platform_instance, + ) + + # Add query to SqlParsingAggregator + self.sql_parsing_aggregator.add_known_query_lineage( + KnownQueryLineageInfo( + query_text=query.query, + upstreams=upstream_urns, + downstream=downstream_urn, + ), + merge_lineage=True, + ) + + # Add observed query + self.sql_parsing_aggregator.add_observed_query( + ObservedQuery( + query=query.query, + timestamp=query.submitted_ts, + user=CorpUserUrn(username=query.username), + default_db=self.default_db, + ) + ) + + def _map_dremio_dataset_to_urn( + self, + dremio_source: str, + dremio_path: List[str], + dremio_dataset: str, + ) -> Optional[str]: + """ + Map a Dremio dataset to a DataHub URN. + """ + mapping = self.source_map.get(dremio_source.lower()) + if not mapping: + return None + + platform = mapping.platform + if not platform: + return None + + platform_instance = mapping.platform_instance + env = mapping.env or self.config.env + + root_path = "" + database_name = "" + + if mapping.dremio_source_category == "file_object_storage": + if mapping.root_path: + root_path = f"{mapping.root_path[1:]}/" + dremio_dataset = f"{root_path}{'/'.join(dremio_path[1:])}/{dremio_dataset}" + else: + if mapping.database_name: + database_name = f"{mapping.database_name}." 
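+            # Non-file sources (e.g. relational databases): qualify the dataset with the
+            # optional database name and the dot-joined remainder of the Dremio path.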
+ dremio_dataset = ( + f"{database_name}{'.'.join(dremio_path[1:])}.{dremio_dataset}" + ) + + if platform_instance: + return make_dataset_urn_with_platform_instance( + platform=platform.lower(), + name=dremio_dataset, + platform_instance=platform_instance, + env=env, + ) + + return make_dataset_urn_with_platform_instance( + platform=platform.lower(), + name=dremio_dataset, + platform_instance=None, + env=env, + ) + + def get_report(self) -> SourceReport: + """ + Get the source report. + """ + return self.report + + +def build_dremio_source_map( + dremio_sources: Iterable[DremioSourceContainer], + source_mappings_config: List[DremioSourceMapping], +) -> Dict[str, DremioSourceMapEntry]: + """ + Builds a source mapping dictionary to support external lineage generation across + multiple Dremio sources, based on provided configuration mappings. + + This method operates as follows: + + Returns: + Dict[str, Dict]: A dictionary (`source_map`) where each key is a source name + (lowercased) and each value is another entry containing: + - `platform`: The source platform. + - `source_name`: The source name. + - `dremio_source_category`: The type mapped to DataHub, + e.g., "database", "folder". + - Optional `root_path`, `database_name`, `platform_instance`, + and `env` if provided in the configuration. + Example: + This method is used internally within the class to generate mappings before + creating cross-platform lineage. + + """ + source_map = {} + for source in dremio_sources: + current_source_name = source.container_name + + source_type = source.dremio_source_type.lower() + source_category = DremioToDataHubSourceTypeMapping.get_category(source_type) + datahub_platform = DremioToDataHubSourceTypeMapping.get_datahub_platform( + source_type + ) + root_path = source.root_path.lower() if source.root_path else "" + database_name = source.database_name.lower() if source.database_name else "" + source_present = False + + for mapping in source_mappings_config: + if mapping.source_name.lower() == current_source_name.lower(): + source_map[current_source_name.lower()] = DremioSourceMapEntry( + platform=mapping.platform, + source_name=mapping.source_name, + dremio_source_category=source_category, + root_path=root_path, + database_name=database_name, + platform_instance=mapping.platform_instance, + env=mapping.env, + ) + source_present = True + break + + if not source_present: + source_map[current_source_name.lower()] = DremioSourceMapEntry( + platform=datahub_platform, + source_name=current_source_name, + dremio_source_category=source_category, + root_path=root_path, + database_name=database_name, + platform_instance=None, + env=None, + ) + + return source_map diff --git a/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py b/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py index c593f36ffbc5d5..1211e39bed5547 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py @@ -265,7 +265,7 @@ def ensure_field_level_settings_are_normalized( def any_field_level_metrics_enabled(self) -> bool: return any( getattr(self, field_name) - for field_name in self.__fields__ + for field_name in self.model_fields if field_name.startswith("include_field_") ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py b/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py index b1f050b51d25c1..55ca9226053940 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py @@ -79,7 +79,7 @@ def ensure_field_level_settings_are_normalized(self) -> "DataLakeProfilerConfig" # Disable all field-level metrics. if self.profile_table_level_only: - for field_name in self.__fields__: + for field_name in self.model_fields: if field_name.startswith("include_field_"): setattr(self, field_name, False) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py index 332129872c4787..ef588e713f9951 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py @@ -227,7 +227,7 @@ def __init__(self, ctx: PipelineContext, config: SqlQueriesSourceConfig): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "SqlQueriesSource": - config = SqlQueriesSourceConfig.parse_obj(config_dict) + config = SqlQueriesSourceConfig.model_validate(config_dict) return cls(ctx, config) def get_report(self) -> SqlQueriesSourceReport: @@ -498,6 +498,6 @@ def create( # Set validation context for URN creation cls._validation_context = config try: - return cls.parse_obj(entry_dict) + return cls.model_validate(entry_dict) finally: cls._validation_context = None diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py index 15de4fbe56ecca..d64f18f33e8dec 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Generic, Optional, Type, TypeVar import pydantic -from pydantic import model_validator +from pydantic import BaseModel as GenericModel, model_validator from pydantic.fields import Field from datahub.configuration.common import ( @@ -12,7 +12,6 @@ DynamicTypedConfig, HiddenFromDocs, ) -from datahub.configuration.pydantic_migration_helpers import GenericModel from datahub.configuration.time_window_config import BaseTimeWindowConfig from datahub.configuration.validate_field_rename import pydantic_renamed_field from datahub.ingestion.api.common import PipelineContext diff --git a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/compiler.py b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/compiler.py index d14d39f6a38d44..5fe3df85d7b065 100644 --- a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/compiler.py +++ b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/compiler.py @@ -20,6 +20,7 @@ from datahub.api.entities.assertion.datahub_assertion import DataHubAssertion from datahub.api.entities.assertion.field_assertion import FieldValuesAssertion from datahub.api.entities.assertion.freshness_assertion import ( + CronFreshnessAssertion, FixedIntervalFreshnessAssertion, ) from datahub.emitter.mce_builder import make_assertion_urn @@ -110,7 +111,7 @@ def compile( result.status = "failure" result.report.report_failure( assertion_spec.get_id(), - f"Failed to compile assertion of type {assertion_spec.assertion.type} due to error: {e}", + f"Failed to compile assertion of type {assertion_spec.type} due to error: {e}", ) result.report.num_compile_failed += 1 if result.report.num_compile_succeeded > 0: @@ -141,27 +142,30 
@@ def process_assertion(self, assertion: DataHubAssertion) -> Tuple[str, str]: # For field values assertion, metric is number or percentage of rows that do not satify # operator condition. # For remaining assertions, numeric metric is discernible in assertion definition itself. - metric_definition = self.metric_generator.metric_sql(assertion.assertion) + metric_definition = self.metric_generator.metric_sql(assertion) - if isinstance(assertion.assertion, FixedIntervalFreshnessAssertion): + if isinstance(assertion, FixedIntervalFreshnessAssertion): assertion_sql = self.metric_evaluator.operator_sql( LessThanOrEqualToOperator( type="less_than_or_equal_to", - value=assertion.assertion.lookback_interval.total_seconds(), + value=assertion.lookback_interval.total_seconds(), ), metric_definition, ) - elif isinstance(assertion.assertion, FieldValuesAssertion): + elif isinstance(assertion, CronFreshnessAssertion): + # CronFreshnessAssertion does not have an operator, skip operator_sql + assertion_sql = metric_definition + elif isinstance(assertion, FieldValuesAssertion): assertion_sql = self.metric_evaluator.operator_sql( LessThanOrEqualToOperator( type="less_than_or_equal_to", - value=assertion.assertion.failure_threshold.value, + value=assertion.failure_threshold.value, ), metric_definition, ) else: assertion_sql = self.metric_evaluator.operator_sql( - assertion.assertion.operator, metric_definition + assertion.operator, metric_definition ) dmf_name = get_dmf_name(assertion) @@ -169,35 +173,37 @@ def process_assertion(self, assertion: DataHubAssertion) -> Tuple[str, str]: args_create_dmf, args_add_dmf = get_dmf_args(assertion) - entity_name = get_entity_name(assertion.assertion) + entity_name = get_entity_name(assertion) - self._entity_schedule_history.setdefault( - assertion.assertion.entity, assertion.assertion.trigger - ) - if ( - assertion.assertion.entity in self._entity_schedule_history - and self._entity_schedule_history[assertion.assertion.entity] - != assertion.assertion.trigger + if assertion.trigger: + self._entity_schedule_history.setdefault( + assertion.entity, assertion.trigger + ) + if assertion.trigger and ( + assertion.entity in self._entity_schedule_history + and self._entity_schedule_history[assertion.entity] != assertion.trigger ): raise ValueError( "Assertions on same entity must have same schedules as of now." 
- f" Found different schedules on entity {assertion.assertion.entity} ->" - f" ({self._entity_schedule_history[assertion.assertion.entity].trigger})," - f" ({assertion.assertion.trigger.trigger})" + f" Found different schedules on entity {assertion.entity} ->" + f" ({self._entity_schedule_history[assertion.entity].trigger})," + f" ({assertion.trigger.trigger})" ) - dmf_schedule = get_dmf_schedule(assertion.assertion.trigger) + dmf_schedule = ( + get_dmf_schedule(assertion.trigger) if assertion.trigger else None + ) dmf_definition = self.dmf_handler.create_dmf( f"{dmf_schema_name}.{dmf_name}", args_create_dmf, - assertion.assertion.description - or f"Created via DataHub for assertion {make_assertion_urn(assertion.get_id())} of type {assertion.assertion.type}", + assertion.description + or f"Created via DataHub for assertion {make_assertion_urn(assertion.get_id())} of type {assertion.type}", assertion_sql, ) dmf_association = self.dmf_handler.add_dmf_to_table( f"{dmf_schema_name}.{dmf_name}", args_add_dmf, - dmf_schedule, + dmf_schedule or "", # type: ignore[arg-type] ".".join(entity_name), ) @@ -217,7 +223,7 @@ def get_dmf_args(assertion: DataHubAssertion) -> Tuple[str, str]: # So we fetch any one column from table's schema args_create_dmf = "ARGT TABLE({col_name} {col_type})" args_add_dmf = "{col_name}" - entity_schema = get_entity_schema(assertion.assertion) + entity_schema = get_entity_schema(assertion) if entity_schema: for col_dict in entity_schema: return args_create_dmf.format( diff --git a/metadata-ingestion/src/datahub/pydantic/compat.py b/metadata-ingestion/src/datahub/pydantic/compat.py deleted file mode 100644 index f6c1208ae85afd..00000000000000 --- a/metadata-ingestion/src/datahub/pydantic/compat.py +++ /dev/null @@ -1,58 +0,0 @@ -import functools -from typing import Any, Callable, Optional, TypeVar, cast - -# Define a type variable for the decorator -F = TypeVar("F", bound=Callable[..., Any]) - - -# Check which Pydantic version is installed -def get_pydantic_version() -> int: - """Determine if Pydantic v1 or v2 is installed.""" - try: - import pydantic - - version = pydantic.__version__ - return 1 if version.startswith("1.") else 2 - except (ImportError, AttributeError): - # Default to v1 if we can't determine version - return 1 - - -PYDANTIC_VERSION = get_pydantic_version() - - -# Create compatibility layer for dict-like methods -def compat_dict_method(v1_method: Optional[Callable] = None) -> Callable: - """ - Decorator to make a dict method work with both Pydantic v1 and v2. 
- - In v1: Uses the decorated method (typically dict) - In v2: Redirects to model_dump with appropriate parameter mapping - """ - - def decorator(func: F) -> F: - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - if PYDANTIC_VERSION >= 2: - # Map v1 parameters to v2 parameters - # exclude -> exclude - # exclude_unset -> exclude_unset - # exclude_defaults -> exclude_defaults - # exclude_none -> exclude_none - # by_alias -> by_alias - model_dump_kwargs = kwargs.copy() - - # Handle the 'exclude' parameter differently between versions - exclude = kwargs.get("exclude", set()) - if isinstance(exclude, (set, dict)): - model_dump_kwargs["exclude"] = exclude - - return self.model_dump(**model_dump_kwargs) - return func(self, *args, **kwargs) - - return cast(F, wrapper) - - # Allow use as both @compat_dict_method and @compat_dict_method() - if v1_method is None: - return decorator - return decorator(v1_method) diff --git a/metadata-ingestion/src/datahub/sdk/search_filters.py b/metadata-ingestion/src/datahub/sdk/search_filters.py index da69f7667e4d1d..2b3168b136a932 100644 --- a/metadata-ingestion/src/datahub/sdk/search_filters.py +++ b/metadata-ingestion/src/datahub/sdk/search_filters.py @@ -19,10 +19,6 @@ from pydantic import field_validator from datahub.configuration.common import ConfigModel -from datahub.configuration.pydantic_migration_helpers import ( - PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR, - PYDANTIC_VERSION_2, -) from datahub.ingestion.graph.client import flexible_entity_type_to_graphql from datahub.ingestion.graph.filters import ( FilterOperator, @@ -59,10 +55,7 @@ def dfs(self) -> Iterator[_BaseFilter]: def _field_discriminator(cls) -> str: if cls is _BaseFilter: raise ValueError("Cannot get discriminator for _BaseFilter") - if PYDANTIC_VERSION_2: - fields: dict = cls.model_fields # type: ignore - else: - fields = cls.__fields__ # type: ignore + fields: dict = cls.model_fields # type: ignore # Assumes that there's only one field name per filter. # If that's not the case, this method should be overridden. @@ -516,9 +509,8 @@ def _parse_and_like_filter(value: Any) -> Any: return value -if TYPE_CHECKING or not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR: - # The `not TYPE_CHECKING` bit is required to make the linter happy, - # since we currently only run mypy with pydantic v1. +if TYPE_CHECKING: + # Simple union for type checking (mypy) Filter = Union[ _And, _Or, @@ -535,12 +527,11 @@ def _parse_and_like_filter(value: Any) -> Any: _TagFilter, _CustomCondition, ] - - _And.update_forward_refs() - _Or.update_forward_refs() - _Not.update_forward_refs() else: - from pydantic import Discriminator, Tag + # Runtime union with validators + # Pydantic v2's "smart union" matching will automatically discriminate based on unique fields. + # Note: We could use explicit Discriminator/Tag (available in Pydantic 2.4+) for slightly + # better performance, but the simple union approach works well across all Pydantic v2 versions. def _parse_json_from_string(value: Any) -> Any: if isinstance(value, str): @@ -551,39 +542,22 @@ def _parse_json_from_string(value: Any) -> Any: else: return value - # TODO: Once we're fully on pydantic 2, we can use a RootModel here. - # That way we'd be able to attach methods to the Filter type. - # e.g. replace load_filters(...) with Filter.load(...) 
Filter = Annotated[ - Annotated[ - Union[ - Annotated[_And, Tag(_And._field_discriminator())], - Annotated[_Or, Tag(_Or._field_discriminator())], - Annotated[_Not, Tag(_Not._field_discriminator())], - Annotated[ - _EntityTypeFilter, Tag(_EntityTypeFilter._field_discriminator()) - ], - Annotated[ - _EntitySubtypeFilter, - Tag(_EntitySubtypeFilter._field_discriminator()), - ], - Annotated[_StatusFilter, Tag(_StatusFilter._field_discriminator())], - Annotated[_PlatformFilter, Tag(_PlatformFilter._field_discriminator())], - Annotated[_DomainFilter, Tag(_DomainFilter._field_discriminator())], - Annotated[ - _ContainerFilter, Tag(_ContainerFilter._field_discriminator()) - ], - Annotated[_EnvFilter, Tag(_EnvFilter._field_discriminator())], - Annotated[_OwnerFilter, Tag(_OwnerFilter._field_discriminator())], - Annotated[ - _GlossaryTermFilter, Tag(_GlossaryTermFilter._field_discriminator()) - ], - Annotated[_TagFilter, Tag(_TagFilter._field_discriminator())], - Annotated[ - _CustomCondition, Tag(_CustomCondition._field_discriminator()) - ], - ], - Discriminator(_filter_discriminator), + Union[ + _And, + _Or, + _Not, + _EntityTypeFilter, + _EntitySubtypeFilter, + _StatusFilter, + _PlatformFilter, + _DomainFilter, + _ContainerFilter, + _EnvFilter, + _OwnerFilter, + _GlossaryTermFilter, + _TagFilter, + _CustomCondition, ], pydantic.BeforeValidator(_parse_and_like_filter), pydantic.BeforeValidator(_parse_json_from_string), @@ -596,10 +570,7 @@ def _parse_json_from_string(value: Any) -> Any: def load_filters(obj: Any) -> Filter: - if PYDANTIC_VERSION_2: - return pydantic.TypeAdapter(Filter).validate_python(obj) # type: ignore - else: - return pydantic.TypeAdapter(Filter).validate_python(obj) # type: ignore + return pydantic.TypeAdapter(Filter).validate_python(obj) # type: ignore # We need FilterDsl for two reasons: diff --git a/metadata-ingestion/src/datahub/sql_parsing/_models.py b/metadata-ingestion/src/datahub/sql_parsing/_models.py index d586e7d6d9045b..044dbda8a583c4 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/_models.py +++ b/metadata-ingestion/src/datahub/sql_parsing/_models.py @@ -4,7 +4,6 @@ import sqlglot from pydantic import BaseModel -from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2 from datahub.metadata.schema_classes import SchemaFieldDataTypeClass @@ -16,17 +15,13 @@ class _ParserBaseModel( }, ): def json(self, *args: Any, **kwargs: Any) -> str: - if PYDANTIC_VERSION_2: - return super().model_dump_json(*args, **kwargs) # type: ignore - else: - return super().json(*args, **kwargs) + return super().model_dump_json(*args, **kwargs) # type: ignore @functools.total_ordering class _FrozenModel(_ParserBaseModel, frozen=True): def __lt__(self, other: "_FrozenModel") -> bool: - # TODO: The __fields__ attribute is deprecated in Pydantic v2. 
- for field in self.__fields__: + for field in self.model_fields: self_v = getattr(self, field) other_v = getattr(other, field) diff --git a/metadata-ingestion/src/datahub/utilities/lossy_collections.py b/metadata-ingestion/src/datahub/utilities/lossy_collections.py index 31d6d0eb842d04..06dac9b6ac42c1 100644 --- a/metadata-ingestion/src/datahub/utilities/lossy_collections.py +++ b/metadata-ingestion/src/datahub/utilities/lossy_collections.py @@ -1,8 +1,6 @@ import random from typing import Dict, Generic, Iterable, Iterator, List, Set, TypeVar, Union -from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2 - T = TypeVar("T") _KT = TypeVar("_KT") _VT = TypeVar("_VT") @@ -47,15 +45,11 @@ def __repr__(self) -> str: def __str__(self) -> str: return repr(self) - if PYDANTIC_VERSION_2: - # With pydantic 2, it doesn't recognize that this is a list subclass, - # so we need to make it explicit. - - @classmethod - def __get_pydantic_core_schema__(cls, source_type, handler): # type: ignore - from pydantic_core import core_schema + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): # type: ignore + from pydantic_core import core_schema - return core_schema.no_info_after_validator_function(cls, handler(list)) + return core_schema.no_info_after_validator_function(cls, handler(list)) def as_obj(self) -> List[Union[T, str]]: from datahub.ingestion.api.report import Report diff --git a/metadata-ingestion/tests/integration/dynamodb/test_dynamodb.py b/metadata-ingestion/tests/integration/dynamodb/test_dynamodb.py index 49e4823009d921..190e73ec9ed4dc 100644 --- a/metadata-ingestion/tests/integration/dynamodb/test_dynamodb.py +++ b/metadata-ingestion/tests/integration/dynamodb/test_dynamodb.py @@ -112,11 +112,11 @@ def test_dynamodb(pytestconfig, tmp_path): minimum_values_threshold=1, info_types_config={ "Phone_Number": InfoTypeConfig( - prediction_factors_and_weights=PredictionFactorsAndWeights( - name=0.7, - description=0, - datatype=0, - values=0.3, + Prediction_Factors_and_Weights=PredictionFactorsAndWeights( + Name=0.7, + Description=0, + Datatype=0, + Values=0.3, ) ) }, diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py index 7059bc3b58eb78..18964da45b32db 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py @@ -89,18 +89,18 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): confidence_level_threshold=0.58, info_types_config={ "Age": InfoTypeConfig( - prediction_factors_and_weights=PredictionFactorsAndWeights( - name=0, values=1, description=0, datatype=0 + Prediction_Factors_and_Weights=PredictionFactorsAndWeights( + Name=0, Values=1, Description=0, Datatype=0 ) ), "CloudRegion": InfoTypeConfig( - prediction_factors_and_weights=PredictionFactorsAndWeights( - name=0, - description=0, - datatype=0, - values=1, + Prediction_Factors_and_Weights=PredictionFactorsAndWeights( + Name=0, + Description=0, + Datatype=0, + Values=1, ), - values=ValuesFactorConfig( + Values=ValuesFactorConfig( prediction_type="regex", regex=[ r"(af|ap|ca|eu|me|sa|us)-(central|north|(north(?:east|west))|south|south(?:east|west)|east|west)-\d+" diff --git a/metadata-ingestion/tests/unit/cli/assertion/dmf_definitions.sql b/metadata-ingestion/tests/unit/cli/assertion/dmf_definitions.sql index 85056e150b9b33..917caa85acbeaf 100644 --- 
a/metadata-ingestion/tests/unit/cli/assertion/dmf_definitions.sql +++ b/metadata-ingestion/tests/unit/cli/assertion/dmf_definitions.sql @@ -7,7 +7,7 @@ COMMENT = 'Created via DataHub for assertion urn:li:assertion:025cce4dd4123c0f007908011a9c64d7 of type freshness' AS $$ - select case when metric <= 3600 then 1 else 0 end from (select timediff( + select case when metric <= 3600.0 then 1 else 0 end from (select timediff( second, max(col_timestamp::TIMESTAMP_LTZ), SNOWFLAKE.CORE.DATA_METRIC_SCHEDULED_TIME() diff --git a/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py b/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py index 8f72dc1178f256..62dd3b78302c83 100644 --- a/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py +++ b/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py @@ -7,9 +7,6 @@ import yaml from pydantic import ValidationError -from datahub.configuration.pydantic_migration_helpers import ( - PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR, -) from datahub.ingestion.graph.filters import ( RemovedStatusFilter, SearchFilterRule, @@ -312,37 +309,28 @@ def test_filter_discriminator() -> None: ) -@pytest.mark.skipif( - not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR, - reason="Tagged union w/ callable discriminator is not supported by the current pydantic version", -) def test_tagged_union_error_messages() -> None: - # With pydantic v1, we'd get 10+ validation errors and it'd be hard to - # understand what went wrong. With v2, we get a single simple error message. + # With pydantic v2, we get validation errors for each union member with pytest.raises( ValidationError, match=re.compile( - r"1 validation error.*entity_type\.entity_type.*Input should be a valid list", + r"validation error.*entity_type.*Input should be a valid list", re.DOTALL, ), ): load_filters({"entity_type": 6}) - # Even when within an "and" clause, we get a single error message. + # Without discriminators, we get verbose union errors for unknown fields with pytest.raises( ValidationError, match=re.compile( - r"1 validation error.*Input tag 'unknown_field' found using .+ does not match any of the expected tags:.+union_tag_invalid", + r"validation error.*unknown_field.*Extra inputs are not permitted", re.DOTALL, ), ): load_filters({"and": [{"unknown_field": 6}]}) -@pytest.mark.skipif( - not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR, - reason="Tagged union w/ callable discriminator is not supported by the current pydantic version", -) def test_filter_before_validators() -> None: # Test that we can load a filter from a string. # Sometimes we get filters encoded as JSON, and we want to handle those gracefully. 
@@ -355,7 +343,7 @@ def test_filter_before_validators() -> None: with pytest.raises( ValidationError, match=re.compile( - r"1 validation error.+Unable to extract tag using discriminator", re.DOTALL + r"validation error.*Input should be a valid dictionary", re.DOTALL ), ): load_filters("this is invalid json but should not raise a json error") @@ -379,7 +367,7 @@ def test_filter_before_validators() -> None: with pytest.raises( ValidationError, match=re.compile( - r"1 validation error.*container\.entity_type.*Extra inputs are not permitted.*", + r"validation error.*Extra inputs are not permitted.*", re.DOTALL, ), ): diff --git a/smoke-test/pyproject.toml b/smoke-test/pyproject.toml index c0a67cc3348d06..9987599d4d815d 100644 --- a/smoke-test/pyproject.toml +++ b/smoke-test/pyproject.toml @@ -67,7 +67,7 @@ ban-relative-imports = "all" "__init__.py" = ["F401"] [tool.mypy] -plugins = ["pydantic.mypy", "pydantic.v1.mypy"] +plugins = ["pydantic.mypy"] exclude = "^(venv/|build/|dist/)" ignore_missing_imports = true namespace_packages = false From 3a360ce235e07bbf481a83fc25ba1ac60f818ccb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20G=C3=B3mez=20Villamor?= Date: Tue, 11 Nov 2025 11:27:48 +0000 Subject: [PATCH 2/4] refactor: migrate ConfigModel to Pydantic v2 and consolidate config - Migrate ConfigModel from Pydantic v1 style (class Config) to v2 style (model_config = ConfigDict) - Replace 35 instances of manual 'model_config = {"extra": "forbid"}' with ConfigModel inheritance across 9 files - Simplify Filter union in search_filters.py by removing TYPE_CHECKING conditional - Fix redshift.py to use parse_obj_allow_extras() instead of direct Config.extra manipulation This eliminates code duplication and centralizes Pydantic configuration policy. --- .../api/entities/assertion/assertion.py | 11 ++- .../assertion/assertion_config_spec.py | 7 +- .../entities/assertion/assertion_operator.py | 69 +++++------------- .../entities/assertion/assertion_trigger.py | 18 ++--- .../api/entities/assertion/field_assertion.py | 7 +- .../datahub/api/entities/assertion/filter.py | 5 +- .../api/entities/datacontract/assertion.py | 6 +- .../datacontract/assertion_operator.py | 23 +++--- .../api/entities/datacontract/datacontract.py | 7 +- .../src/datahub/configuration/common.py | 35 ++++----- .../ingestion/source/redshift/redshift.py | 7 +- .../src/datahub/sdk/search_filters.py | 71 +++++++------------ 12 files changed, 95 insertions(+), 171 deletions(-) diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py index 9a934be91e9e77..ae1e8962454a2c 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py @@ -1,15 +1,14 @@ from abc import abstractmethod from typing import Optional -from pydantic import BaseModel, Field +from pydantic import Field from datahub.api.entities.assertion.assertion_trigger import AssertionTrigger +from datahub.configuration.common import ConfigModel from datahub.metadata.com.linkedin.pegasus2avro.assertion import AssertionInfo -class BaseAssertionProtocol(BaseModel): - model_config = {"extra": "forbid"} - +class BaseAssertionProtocol(ConfigModel): @abstractmethod def get_id(self) -> str: pass @@ -27,9 +26,7 @@ def get_assertion_trigger( pass -class BaseAssertion(BaseModel): - model_config = {"extra": "forbid"} - +class BaseAssertion(ConfigModel): id_raw: Optional[str] = Field( default=None, description="The raw 
id of the assertion." diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_config_spec.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_config_spec.py index d12d79a9fbbd91..8fb171aaf952f7 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_config_spec.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_config_spec.py @@ -1,13 +1,14 @@ from typing import List, Optional -from pydantic import BaseModel, Field +from pydantic import Field from ruamel.yaml import YAML from typing_extensions import Literal from datahub.api.entities.assertion.datahub_assertion import DataHubAssertion +from datahub.configuration.common import ConfigModel -class AssertionsConfigSpec(BaseModel): +class AssertionsConfigSpec(ConfigModel): """ Declarative configuration specification for datahub assertions. @@ -18,8 +19,6 @@ class AssertionsConfigSpec(BaseModel): In future, this would invoke datahub GraphQL API to upsert assertions. """ - model_config = {"extra": "forbid"} - version: Literal[1] id: Optional[str] = Field( diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py index 7e1259a22c0d6f..b6b57a89f13b9f 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py @@ -1,9 +1,9 @@ import json from typing import List, Optional, Union -from pydantic import BaseModel from typing_extensions import Literal, Protocol +from datahub.configuration.common import ConfigModel from datahub.metadata.schema_classes import ( AssertionStdOperatorClass, AssertionStdParameterClass, @@ -61,9 +61,7 @@ def _generate_assertion_std_parameters( ) -class EqualToOperator(BaseModel): - model_config = {"extra": "forbid"} - +class EqualToOperator(ConfigModel): type: Literal["equal_to"] value: Union[str, int, float] @@ -76,8 +74,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class NotEqualToOperator(BaseModel): - model_config = {"extra": "forbid"} +class NotEqualToOperator(ConfigModel): type: Literal["not_equal_to"] value: Union[str, int, float] @@ -90,9 +87,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class BetweenOperator(BaseModel): - model_config = {"extra": "forbid"} - +class BetweenOperator(ConfigModel): type: Literal["between"] min: Union[int, float] max: Union[int, float] @@ -108,9 +103,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: ) -class LessThanOperator(BaseModel): - model_config = {"extra": "forbid"} - +class LessThanOperator(ConfigModel): type: Literal["less_than"] value: Union[int, float] @@ -123,9 +116,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class GreaterThanOperator(BaseModel): - model_config = {"extra": "forbid"} - +class GreaterThanOperator(ConfigModel): type: Literal["greater_than"] value: Union[int, float] @@ -138,9 +129,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class LessThanOrEqualToOperator(BaseModel): - model_config = {"extra": "forbid"} - +class LessThanOrEqualToOperator(ConfigModel): type: Literal["less_than_or_equal_to"] value: Union[int, float] @@ -153,9 
+142,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class GreaterThanOrEqualToOperator(BaseModel): - model_config = {"extra": "forbid"} - +class GreaterThanOrEqualToOperator(ConfigModel): type: Literal["greater_than_or_equal_to"] value: Union[int, float] @@ -168,9 +155,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class InOperator(BaseModel): - model_config = {"extra": "forbid"} - +class InOperator(ConfigModel): type: Literal["in"] value: List[Union[str, float, int]] @@ -183,9 +168,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class NotInOperator(BaseModel): - model_config = {"extra": "forbid"} - +class NotInOperator(ConfigModel): type: Literal["not_in"] value: List[Union[str, float, int]] @@ -198,9 +181,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class IsNullOperator(BaseModel): - model_config = {"extra": "forbid"} - +class IsNullOperator(ConfigModel): type: Literal["is_null"] operator: str = AssertionStdOperatorClass.NULL @@ -212,9 +193,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters() -class NotNullOperator(BaseModel): - model_config = {"extra": "forbid"} - +class NotNullOperator(ConfigModel): type: Literal["is_not_null"] operator: str = AssertionStdOperatorClass.NOT_NULL @@ -226,9 +205,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters() -class IsTrueOperator(BaseModel): - model_config = {"extra": "forbid"} - +class IsTrueOperator(ConfigModel): type: Literal["is_true"] operator: str = AssertionStdOperatorClass.IS_TRUE @@ -240,9 +217,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters() -class IsFalseOperator(BaseModel): - model_config = {"extra": "forbid"} - +class IsFalseOperator(ConfigModel): type: Literal["is_false"] operator: str = AssertionStdOperatorClass.IS_FALSE @@ -254,9 +229,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters() -class ContainsOperator(BaseModel): - model_config = {"extra": "forbid"} - +class ContainsOperator(ConfigModel): type: Literal["contains"] value: str @@ -269,9 +242,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class EndsWithOperator(BaseModel): - model_config = {"extra": "forbid"} - +class EndsWithOperator(ConfigModel): type: Literal["ends_with"] value: str @@ -284,9 +255,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class StartsWithOperator(BaseModel): - model_config = {"extra": "forbid"} - +class StartsWithOperator(ConfigModel): type: Literal["starts_with"] value: str @@ -299,9 +268,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class MatchesRegexOperator(BaseModel): - model_config = {"extra": "forbid"} - +class MatchesRegexOperator(ConfigModel): type: Literal["matches_regex"] value: str diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py index 
5542cdab5f9dee..9d34581c7a235b 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py @@ -2,13 +2,13 @@ from typing import Union import humanfriendly -from pydantic import BaseModel, Field, RootModel, field_validator +from pydantic import Field, RootModel, field_validator from typing_extensions import Literal +from datahub.configuration.common import ConfigModel -class CronTrigger(BaseModel): - model_config = {"extra": "forbid"} +class CronTrigger(ConfigModel): type: Literal["cron"] cron: str = Field( description="The cron expression to use. See https://crontab.guru/ for help." @@ -19,9 +19,7 @@ class CronTrigger(BaseModel): ) -class IntervalTrigger(BaseModel): - model_config = {"extra": "forbid"} - +class IntervalTrigger(ConfigModel): type: Literal["interval"] interval: timedelta @@ -34,15 +32,11 @@ def lookback_interval_to_timedelta(cls, v): raise ValueError("Invalid value.") -class EntityChangeTrigger(BaseModel): - model_config = {"extra": "forbid"} - +class EntityChangeTrigger(ConfigModel): type: Literal["on_table_change"] -class ManualTrigger(BaseModel): - model_config = {"extra": "forbid"} - +class ManualTrigger(ConfigModel): type: Literal["manual"] diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py index 200d0149bd1dc1..5e2eaf130450a4 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Optional, Union -from pydantic import BaseModel, Field +from pydantic import Field from typing_extensions import Literal from datahub.api.entities.assertion.assertion import ( @@ -11,6 +11,7 @@ from datahub.api.entities.assertion.assertion_trigger import AssertionTrigger from datahub.api.entities.assertion.field_metric import FieldMetric from datahub.api.entities.assertion.filter import DatasetFilter +from datahub.configuration.common import ConfigModel from datahub.emitter.mce_builder import datahub_guid from datahub.metadata.com.linkedin.pegasus2avro.assertion import ( AssertionInfo, @@ -29,9 +30,7 @@ ) -class FieldValuesFailThreshold(BaseModel): - model_config = {"extra": "forbid"} - +class FieldValuesFailThreshold(ConfigModel): type: Literal["count", "percentage"] = Field(default="count") value: int = Field(default=0) diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/filter.py b/metadata-ingestion/src/datahub/api/entities/assertion/filter.py index faffc140eee5b4..b1830106a720bd 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/filter.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/filter.py @@ -1,10 +1,9 @@ -from pydantic import BaseModel from typing_extensions import Literal +from datahub.configuration.common import ConfigModel -class SqlFilter(BaseModel): - model_config = {"extra": "forbid"} +class SqlFilter(ConfigModel): type: Literal["sql"] sql: str diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion.py index d135b63389c217..3022131396a241 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion.py @@ -1,9 +1,7 @@ from typing import Optional -from pydantic import 
BaseModel +from datahub.configuration.common import ConfigModel -class BaseAssertion(BaseModel): - model_config = {"extra": "forbid"} - +class BaseAssertion(ConfigModel): description: Optional[str] = None diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py index 3438f7242de24b..366140d975b094 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py @@ -1,8 +1,8 @@ from typing import Optional, Union -from pydantic import BaseModel from typing_extensions import Literal, Protocol +from datahub.configuration.common import ConfigModel from datahub.metadata.schema_classes import ( AssertionStdOperatorClass, AssertionStdParameterClass, @@ -56,8 +56,7 @@ def _generate_assertion_std_parameters( ) -class EqualToOperator(BaseModel): - model_config = {"extra": "forbid"} +class EqualToOperator(ConfigModel): type: Literal["equal_to"] value: Union[str, int, float] @@ -70,8 +69,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class BetweenOperator(BaseModel): - model_config = {"extra": "forbid"} +class BetweenOperator(ConfigModel): type: Literal["between"] min: Union[int, float] max: Union[int, float] @@ -87,8 +85,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: ) -class LessThanOperator(BaseModel): - model_config = {"extra": "forbid"} +class LessThanOperator(ConfigModel): type: Literal["less_than"] value: Union[int, float] @@ -101,8 +98,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class GreaterThanOperator(BaseModel): - model_config = {"extra": "forbid"} +class GreaterThanOperator(ConfigModel): type: Literal["greater_than"] value: Union[int, float] @@ -115,8 +111,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class LessThanOrEqualToOperator(BaseModel): - model_config = {"extra": "forbid"} +class LessThanOrEqualToOperator(ConfigModel): type: Literal["less_than_or_equal_to"] value: Union[int, float] @@ -129,8 +124,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class GreaterThanOrEqualToOperator(BaseModel): - model_config = {"extra": "forbid"} +class GreaterThanOrEqualToOperator(ConfigModel): type: Literal["greater_than_or_equal_to"] value: Union[int, float] @@ -143,8 +137,7 @@ def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters(value=self.value) -class NotNullOperator(BaseModel): - model_config = {"extra": "forbid"} +class NotNullOperator(ConfigModel): type: Literal["not_null"] operator: str = AssertionStdOperatorClass.NOT_NULL diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py b/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py index 61a93482b3a1ba..b38bdb12e1f6fc 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py @@ -1,7 +1,7 @@ import collections from typing import Dict, Iterable, List, Optional, Tuple, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import Field, field_validator 
from ruamel.yaml import YAML from typing_extensions import Literal @@ -11,6 +11,7 @@ ) from datahub.api.entities.datacontract.freshness_assertion import FreshnessAssertion from datahub.api.entities.datacontract.schema_assertion import SchemaAssertion +from datahub.configuration.common import ConfigModel from datahub.emitter.mce_builder import datahub_guid, make_assertion_urn from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.metadata.schema_classes import ( @@ -27,7 +28,7 @@ from datahub.utilities.urns.urn import guess_entity_type -class DataContract(BaseModel): +class DataContract(ConfigModel): """A yml representation of a Data Contract. This model is used as a simpler, Python-native representation of a DataHub data contract. @@ -35,8 +36,6 @@ class DataContract(BaseModel): that can be emitted to DataHub. """ - model_config = {"extra": "forbid"} - version: Literal[1] id: Optional[str] = Field( diff --git a/metadata-ingestion/src/datahub/configuration/common.py b/metadata-ingestion/src/datahub/configuration/common.py index 1c3a9127d0a2e6..ed15fc65c902bb 100644 --- a/metadata-ingestion/src/datahub/configuration/common.py +++ b/metadata-ingestion/src/datahub/configuration/common.py @@ -21,7 +21,7 @@ import pydantic import pydantic_core from cached_property import cached_property -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, ConfigDict, ValidationError from pydantic.fields import Field from typing_extensions import Protocol, Self @@ -110,23 +110,24 @@ def __get_pydantic_json_schema__( return json_schema +def _config_model_schema_extra(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + # We use the custom "hidden_from_docs" attribute to hide fields from the + # autogenerated docs. + remove_fields = [] + for key, prop in schema.get("properties", {}).items(): + if prop.get("hidden_from_docs"): + remove_fields.append(key) + + for key in remove_fields: + del schema["properties"][key] + + class ConfigModel(BaseModel): - class Config: - @staticmethod - def _schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None: - # We use the custom "hidden_from_docs" attribute to hide fields from the - # autogenerated docs. - remove_fields = [] - for key, prop in schema.get("properties", {}).items(): - if prop.get("hidden_from_docs"): - remove_fields.append(key) - - for key in remove_fields: - del schema["properties"][key] - - extra = "forbid" - ignored_types = (cached_property,) - json_schema_extra = _schema_extra + model_config = ConfigDict( + extra="forbid", + ignored_types=(cached_property,), + json_schema_extra=_config_model_schema_extra, + ) @classmethod def parse_obj_allow_extras(cls, obj: Any) -> Self: diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py index 3a4ec6470b90c5..7f4b67224c5613 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py @@ -6,7 +6,6 @@ import humanfriendly # These imports verify that the dependencies are available. 
-import pydantic import redshift_connector from datahub.configuration.common import AllowDenyPattern @@ -233,10 +232,8 @@ def get_platform_instance_id(self) -> str: def test_connection(config_dict: dict) -> TestConnectionReport: test_report = TestConnectionReport() try: - RedshiftConfig.Config.extra = ( - pydantic.Extra.allow - ) # we are okay with extra fields during this stage - config = RedshiftConfig.model_validate(config_dict) + # We are okay with extra fields during this stage + config = RedshiftConfig.parse_obj_allow_extras(config_dict) # source = RedshiftSource(config, report) connection: redshift_connector.Connection = ( RedshiftSource.get_redshift_connection(config) diff --git a/metadata-ingestion/src/datahub/sdk/search_filters.py b/metadata-ingestion/src/datahub/sdk/search_filters.py index 2b3168b136a932..60a9873c12ba90 100644 --- a/metadata-ingestion/src/datahub/sdk/search_filters.py +++ b/metadata-ingestion/src/datahub/sdk/search_filters.py @@ -3,7 +3,6 @@ import abc import json from typing import ( - TYPE_CHECKING, Annotated, Any, ClassVar, @@ -509,9 +508,23 @@ def _parse_and_like_filter(value: Any) -> Any: return value -if TYPE_CHECKING: - # Simple union for type checking (mypy) - Filter = Union[ +# Pydantic v2's "smart union" matching will automatically discriminate based on unique fields. +# Note: We could use explicit Discriminator/Tag (available in Pydantic 2.4+) for slightly +# better performance, but the simple union approach works well across all Pydantic v2 versions. + + +def _parse_json_from_string(value: Any) -> Any: + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + return value + else: + return value + + +Filter = Annotated[ + Union[ _And, _Or, _Not, @@ -526,47 +539,15 @@ def _parse_and_like_filter(value: Any) -> Any: _GlossaryTermFilter, _TagFilter, _CustomCondition, - ] -else: - # Runtime union with validators - # Pydantic v2's "smart union" matching will automatically discriminate based on unique fields. - # Note: We could use explicit Discriminator/Tag (available in Pydantic 2.4+) for slightly - # better performance, but the simple union approach works well across all Pydantic v2 versions. 
- - def _parse_json_from_string(value: Any) -> Any: - if isinstance(value, str): - try: - return json.loads(value) - except json.JSONDecodeError: - return value - else: - return value - - Filter = Annotated[ - Union[ - _And, - _Or, - _Not, - _EntityTypeFilter, - _EntitySubtypeFilter, - _StatusFilter, - _PlatformFilter, - _DomainFilter, - _ContainerFilter, - _EnvFilter, - _OwnerFilter, - _GlossaryTermFilter, - _TagFilter, - _CustomCondition, - ], - pydantic.BeforeValidator(_parse_and_like_filter), - pydantic.BeforeValidator(_parse_json_from_string), - ] - - # Required to resolve forward references to "Filter" - _And.model_rebuild() # type: ignore - _Or.model_rebuild() # type: ignore - _Not.model_rebuild() # type: ignore + ], + pydantic.BeforeValidator(_parse_and_like_filter), + pydantic.BeforeValidator(_parse_json_from_string), +] + +# Required to resolve forward references to "Filter" +_And.model_rebuild() # type: ignore +_Or.model_rebuild() # type: ignore +_Not.model_rebuild() # type: ignore def load_filters(obj: Any) -> Filter: From 30532e7058081ce21debdb19b98cbe62b134fefd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20G=C3=B3mez=20Villamor?= Date: Tue, 11 Nov 2025 18:08:57 +0000 Subject: [PATCH 3/4] refactor(pydantic): complete migration from Pydantic v1 to v2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit completes the migration of all Pydantic v1 legacy syntax to v2 across the entire DataHub Python codebase. **Configuration Migration (18 instances):** - Migrate `class Config:` → `model_config = ConfigDict(...)` - Updated ConfigModel, PermissiveConfigModel, ConnectionModel base classes - Migrated 13 additional model classes across multiple files **Method Migrations (76 instances):** - `.parse_obj()` → `.model_validate()` (38 instances) - `.parse_raw()` → `.model_validate_json()` (2 instances) - `.dict()` → `.model_dump()` (27 instances) - `.json()` → `.model_dump_json()` (4 instances) - `.update_forward_refs()` → `.model_rebuild()` (3 instances) - `.copy()` and `.schema()` - all false positives (dicts/lists/HTTP responses) **Scope of Changes:** - metadata-ingestion/src: 50 instances - metadata-ingestion/tests: 30 instances - datahub-actions: 12 instances - smoke-test: 2 instances **Total: 94 Pydantic v1 calls migrated** **Key Files Updated:** - common.py: Base ConfigModel classes - Multiple source files: sql_queries.py, datahub_classifier.py, schema_assertion.py, etc. - Multiple CLI files: structuredproperties_cli.py, forms_cli.py, dataset_cli.py, etc. - Test files: RDS IAM tests, Unity Catalog tests, assertion tests, etc. - datahub-actions: propagation_action.py, filter tests, consumer offsets - smoke-test: stateful ingestion tests All Pydantic v2 deprecation warnings have been resolved. The codebase is now fully compliant with Pydantic v2 with no remaining v1 syntax. 
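
For reference, the mappings above all reduce to the same mechanical before/after
pattern. A minimal sketch is included below; `ExampleConfig` and its fields are
illustrative placeholders, not classes from this repository:

```python
# Hypothetical model, shown only to illustrate the v1 -> v2 changes applied here.
from pydantic import BaseModel, ConfigDict


class ExampleConfig(BaseModel):
    # v1: `class Config: extra = "forbid"`  ->  v2: model_config = ConfigDict(...)
    model_config = ConfigDict(extra="forbid")

    name: str
    timeout_sec: int = 30


cfg = ExampleConfig.model_validate({"name": "demo"})            # v1: ExampleConfig.parse_obj(...)
cfg = ExampleConfig.model_validate_json(cfg.model_dump_json())  # v1: .parse_raw(... .json())
data = cfg.model_dump(exclude_none=True)                        # v1: cfg.dict(exclude_none=True)
ExampleConfig.model_rebuild()                                   # v1: ExampleConfig.update_forward_refs()
```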
--- .../propagation/docs/propagation_action.py | 2 +- ...hub_cloud_events_consumer_offsets_store.py | 4 ++-- .../filter/test_filter_transformer.py | 18 ++++++++--------- .../api/entities/common/serialized_value.py | 5 ++--- .../entities/datacontract/schema_assertion.py | 5 ++--- .../entities/external/external_entities.py | 18 ++++++++--------- metadata-ingestion/src/datahub/cli/migrate.py | 2 +- .../datahub/cli/specific/dataproduct_cli.py | 2 +- .../src/datahub/cli/specific/dataset_cli.py | 2 +- .../src/datahub/cli/specific/forms_cli.py | 2 +- .../cli/specific/structuredproperties_cli.py | 10 +++++----- .../src/datahub/configuration/common.py | 6 ++---- .../src/datahub/emitter/mcp_builder.py | 6 +++--- .../ingestion/glossary/datahub_classifier.py | 8 +++----- .../src/datahub/ingestion/graph/client.py | 4 ++-- .../src/datahub/ingestion/graph/config.py | 5 +++-- .../datahub/ingestion/sink/datahub_lite.py | 2 +- .../source/data_lake_common/path_spec.py | 5 ++--- .../datahub/ingestion/source/dbt/dbt_core.py | 5 ++--- .../ingestion/source/delta_lake/source.py | 2 +- .../src/datahub/ingestion/source/hex/api.py | 11 ++++------ .../ingestion/source/metadata/lineage.py | 2 +- .../report_server_domain.py | 9 ++++----- .../source/snowflake/snowflake_usage_v2.py | 4 ++-- .../ingestion/source/sql/sql_common.py | 2 +- .../datahub/ingestion/source/sql_queries.py | 5 ++--- .../src/datahub/ingestion/source/superset.py | 8 ++++---- .../transformer/add_dataset_properties.py | 5 +++-- .../ingestion/transformer/dataset_domain.py | 2 +- .../datahub/secret/datahub_secret_store.py | 5 ++--- .../testing/check_sql_parser_result.py | 2 +- .../test_data_quality_assertion.py | 2 +- .../state/test_sql_common_state.py | 2 +- .../tests/unit/test_athena_source.py | 2 +- .../tests/unit/test_mysql_rds_iam.py | 16 +++++++-------- .../tests/unit/test_postgres_rds_iam.py | 18 ++++++++--------- .../tests/unit/test_unity_catalog_source.py | 20 +++++++++---------- smoke-test/tests/test_stateful_ingestion.py | 4 ++-- 38 files changed, 110 insertions(+), 122 deletions(-) diff --git a/datahub-actions/src/datahub_actions/plugin/action/propagation/docs/propagation_action.py b/datahub-actions/src/datahub_actions/plugin/action/propagation/docs/propagation_action.py index 5f1b42433ad83b..6f2ea9cd420709 100644 --- a/datahub-actions/src/datahub_actions/plugin/action/propagation/docs/propagation_action.py +++ b/datahub-actions/src/datahub_actions/plugin/action/propagation/docs/propagation_action.py @@ -280,7 +280,7 @@ def process_schema_field_documentation( if current_documentation_instance.attribution else {} ) - source_details_parsed: SourceDetails = SourceDetails.parse_obj( + source_details_parsed: SourceDetails = SourceDetails.model_validate( source_details ) should_stop_propagation, reason = self.should_stop_propagation( diff --git a/datahub-actions/src/datahub_actions/plugin/source/acryl/datahub_cloud_events_consumer_offsets_store.py b/datahub-actions/src/datahub_actions/plugin/source/acryl/datahub_cloud_events_consumer_offsets_store.py index 4be5e643344015..b7afd78f974ac9 100644 --- a/datahub-actions/src/datahub_actions/plugin/source/acryl/datahub_cloud_events_consumer_offsets_store.py +++ b/datahub-actions/src/datahub_actions/plugin/source/acryl/datahub_cloud_events_consumer_offsets_store.py @@ -46,11 +46,11 @@ def from_serialized_value(cls, value: SerializedValueClass) -> "EventConsumerSta return cls() def to_blob(self) -> bytes: - return self.json().encode() + return self.model_dump_json().encode() @staticmethod def 
from_blob(blob: bytes) -> "EventConsumerState": - return EventConsumerState.parse_raw(blob.decode()) + return EventConsumerState.model_validate_json(blob.decode()) class DataHubEventsConsumerPlatformResourceOffsetsStore: diff --git a/datahub-actions/tests/unit/plugin/transform/filter/test_filter_transformer.py b/datahub-actions/tests/unit/plugin/transform/filter/test_filter_transformer.py index 814a6f3d8d10ee..748d2b7a492b47 100644 --- a/datahub-actions/tests/unit/plugin/transform/filter/test_filter_transformer.py +++ b/datahub-actions/tests/unit/plugin/transform/filter/test_filter_transformer.py @@ -51,7 +51,7 @@ def as_json(self) -> str: def test_returns_none_when_diff_event_type(): - filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( {"event_type": "EntityChangeEvent_v1", "event": {"field1": "a", "field2": "b"}} ) filter_transformer = FilterTransformer(filter_transformer_config) @@ -68,7 +68,7 @@ def test_returns_none_when_diff_event_type(): def test_does_exact_match(): - filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( {"event_type": "EntityChangeEvent_v1", "event": {"field1": "a", "field2": "b"}} ) filter_transformer = FilterTransformer(filter_transformer_config) @@ -83,7 +83,7 @@ def test_does_exact_match(): def test_returns_none_when_no_match(): - filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( {"event_type": "EntityChangeEvent_v1", "event": {"field1": "a", "field2": "b"}} ) filter_transformer = FilterTransformer(filter_transformer_config) @@ -97,7 +97,7 @@ def test_returns_none_when_no_match(): def test_matches_on_nested_event(): - filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( { "event_type": "EntityChangeEvent_v1", "event": {"field1": {"nested_1": {"nested_b": "a"}}}, @@ -112,7 +112,7 @@ def test_matches_on_nested_event(): def test_returns_none_when_no_match_nested_event(): - filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( { "event_type": "EntityChangeEvent_v1", "event": {"field1": {"nested_1": {"nested_b": "a"}}}, @@ -127,7 +127,7 @@ def test_returns_none_when_no_match_nested_event(): def test_returns_none_when_different_data_type(): - filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( { "event_type": "EntityChangeEvent_v1", "event": {"field1": {"nested_1": {"nested_b": "a"}}}, @@ -142,7 +142,7 @@ def test_returns_none_when_different_data_type(): def test_returns_match_when_either_is_present(): - filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( { "event_type": "EntityChangeEvent_v1", "event": {"field1": {"nested_1": ["a", "b"]}}, @@ -157,7 +157,7 @@ def test_returns_match_when_either_is_present(): def test_returns_none_when_neither_is_present(): - filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( { "event_type": "EntityChangeEvent_v1", "event": {"field1": {"nested_1": ["a", "b"]}}, @@ -172,7 +172,7 @@ def test_returns_none_when_neither_is_present(): def test_no_match_when_list_filter_on_dict_obj(): - 
filter_transformer_config = FilterTransformerConfig.parse_obj( + filter_transformer_config = FilterTransformerConfig.model_validate( { "event_type": "EntityChangeEvent_v1", "event": {"field1": {"nested_1": ["a", "b"]}}, diff --git a/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py b/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py index 6eb2025a824328..c4f1519d72865f 100644 --- a/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py +++ b/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py @@ -3,7 +3,7 @@ from typing import Dict, Optional, Type, TypeVar, Union from avrogen.dict_wrapper import DictWrapper -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict import datahub.metadata.schema_classes as models from datahub.metadata.schema_classes import __SCHEMA_TYPES as SCHEMA_TYPES @@ -17,8 +17,7 @@ class SerializedResourceValue(BaseModel): - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) content_type: str blob: bytes diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py index 878e3dbc07e343..b59a264149aa69 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py @@ -3,7 +3,7 @@ import json from typing import List, Union -from pydantic import Field, RootModel +from pydantic import ConfigDict, Field, RootModel from typing_extensions import Literal from datahub.api.entities.datacontract.assertion import BaseAssertion @@ -36,8 +36,7 @@ def model_post_init(self, __context: object) -> None: class FieldListSchemaContract(BaseAssertion): - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) type: Literal["field-list"] diff --git a/metadata-ingestion/src/datahub/api/entities/external/external_entities.py b/metadata-ingestion/src/datahub/api/entities/external/external_entities.py index b0c6a81ffe001b..450ac9bae8d305 100644 --- a/metadata-ingestion/src/datahub/api/entities/external/external_entities.py +++ b/metadata-ingestion/src/datahub/api/entities/external/external_entities.py @@ -137,19 +137,19 @@ def create(self, platform_resource: PlatformResource) -> None: self.entity_class ) ) - entity = self.entity_class(**entity_obj.dict()) + entity = self.entity_class(**entity_obj.model_dump()) # Create updated entity ID with persisted=True entity_id = entity.get_id() - if hasattr(entity_id, "dict"): - entity_id_data = entity_id.dict() + if hasattr(entity_id, "model_dump"): + entity_id_data = entity_id.model_dump() entity_id_data["persisted"] = True # Create new entity ID with updated flags updated_entity_id = type(entity_id)(**entity_id_data) # Update the entity with the new ID (immutable update) - entity_data = entity.dict() # type: ignore[attr-defined] + entity_data = entity.model_dump() # type: ignore[attr-defined] entity_data["id"] = updated_entity_id updated_entity = type(entity)(**entity_data) @@ -359,13 +359,13 @@ def search_entity_by_urn(self, urn: str) -> Optional[TExternalEntityId]: self.entity_class ) ) - entity = self.entity_class(**entity_obj.dict()) + entity = self.entity_class(**entity_obj.model_dump()) # Check if platform instance matches entity_id = entity.get_id() if entity_id.platform_instance == self.platform_instance: # Create a new entity ID with 
the correct state instead of mutating - # All our entity IDs are Pydantic models, so we can use dict() method - entity_data = entity_id.dict() + # All our entity IDs are Pydantic models, so we can use model_dump() method + entity_data = entity_id.model_dump() entity_data["persisted"] = ( True # This entity was found in DataHub ) @@ -433,7 +433,7 @@ def get_entity_from_datahub( entity_obj = platform_resource.resource_info.value.as_pydantic_object( self.entity_class ) - result = self.entity_class(**entity_obj.dict()) + result = self.entity_class(**entity_obj.model_dump()) elif len(platform_resources) > 1: # Handle multiple matches - find the one with matching platform instance target_platform_instance = entity_id.platform_instance @@ -447,7 +447,7 @@ def get_entity_from_datahub( self.entity_class ) ) - entity = self.entity_class(**entity_obj.dict()) + entity = self.entity_class(**entity_obj.model_dump()) if entity.get_id().platform_instance == target_platform_instance: result = entity break diff --git a/metadata-ingestion/src/datahub/cli/migrate.py b/metadata-ingestion/src/datahub/cli/migrate.py index 2e8b3197e63868..0068edbd715495 100644 --- a/metadata-ingestion/src/datahub/cli/migrate.py +++ b/metadata-ingestion/src/datahub/cli/migrate.py @@ -356,7 +356,7 @@ def migrate_containers( if mcp.aspectName == "containerProperties": assert isinstance(mcp.aspect, ContainerPropertiesClass) containerProperties: ContainerPropertiesClass = mcp.aspect - containerProperties.customProperties = newKey.dict( + containerProperties.customProperties = newKey.model_dump( by_alias=True, exclude_none=True ) mcp.aspect = containerProperties diff --git a/metadata-ingestion/src/datahub/cli/specific/dataproduct_cli.py b/metadata-ingestion/src/datahub/cli/specific/dataproduct_cli.py index 29f0a6e47f2ed7..3a8bdf4e444a46 100644 --- a/metadata-ingestion/src/datahub/cli/specific/dataproduct_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/dataproduct_cli.py @@ -245,7 +245,7 @@ def get(urn: str, to_file: str) -> None: if graph.exists(urn): dataproduct: DataProduct = DataProduct.from_datahub(graph=graph, id=urn) click.secho( - f"{json.dumps(dataproduct.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(dataproduct.model_dump(exclude_unset=True, exclude_none=True), indent=2)}" ) if to_file: dataproduct.to_yaml(Path(to_file)) diff --git a/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py b/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py index 001d0ce5c084ee..a79fa574017dc2 100644 --- a/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py @@ -55,7 +55,7 @@ def get(urn: str, to_file: str) -> None: if graph.exists(urn): dataset: Dataset = Dataset.from_datahub(graph=graph, urn=urn) click.secho( - f"{json.dumps(dataset.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(dataset.model_dump(exclude_unset=True, exclude_none=True), indent=2)}" ) if to_file: dataset.to_yaml(Path(to_file)) diff --git a/metadata-ingestion/src/datahub/cli/specific/forms_cli.py b/metadata-ingestion/src/datahub/cli/specific/forms_cli.py index 3049643144a607..4ebca6608959bd 100644 --- a/metadata-ingestion/src/datahub/cli/specific/forms_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/forms_cli.py @@ -41,7 +41,7 @@ def get(urn: str, to_file: str) -> None: if graph.exists(urn): form: Forms = Forms.from_datahub(graph=graph, urn=urn) click.secho( - f"{json.dumps(form.dict(exclude_unset=True, exclude_none=True), 
indent=2)}" + f"{json.dumps(form.model_dump(exclude_unset=True, exclude_none=True), indent=2)}" ) if to_file: form.to_yaml(Path(to_file)) diff --git a/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py b/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py index 4ee0b2b8163c37..33621490cc2940 100644 --- a/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py @@ -52,7 +52,7 @@ def get(urn: str, to_file: str) -> None: StructuredProperties.from_datahub(graph=graph, urn=urn) ) click.secho( - f"{json.dumps(structuredproperties.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(structuredproperties.model_dump(exclude_unset=True, exclude_none=True), indent=2)}" ) if to_file: structuredproperties.to_yaml(Path(to_file)) @@ -97,19 +97,19 @@ def to_yaml_list( # breakpoint() if existing_urn in {obj.urn for obj in objects}: existing_objects[i] = next( - obj.dict(exclude_unset=True, exclude_none=True) + obj.model_dump(exclude_unset=True, exclude_none=True) for obj in objects if obj.urn == existing_urn ) new_objects = [ - obj.dict(exclude_unset=True, exclude_none=True) + obj.model_dump(exclude_unset=True, exclude_none=True) for obj in objects if obj.urn not in existing_urns ] serialized_objects = existing_objects + new_objects else: serialized_objects = [ - obj.dict(exclude_unset=True, exclude_none=True) for obj in objects + obj.model_dump(exclude_unset=True, exclude_none=True) for obj in objects ] with open(file, "w") as fp: @@ -126,7 +126,7 @@ def to_yaml_list( else: for structuredproperty in structuredproperties: click.secho( - f"{json.dumps(structuredproperty.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(structuredproperty.model_dump(exclude_unset=True, exclude_none=True), indent=2)}" ) else: logger.info( diff --git a/metadata-ingestion/src/datahub/configuration/common.py b/metadata-ingestion/src/datahub/configuration/common.py index ed15fc65c902bb..ecea1bc1d9ac51 100644 --- a/metadata-ingestion/src/datahub/configuration/common.py +++ b/metadata-ingestion/src/datahub/configuration/common.py @@ -155,15 +155,13 @@ class PermissiveConfigModel(ConfigModel): # but still allow the user to pass in arbitrary fields that we don't care about. # It is usually used for argument bags that are passed through to third-party libraries. 
- class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class ConnectionModel(BaseModel): """Represents the config associated with a connection""" - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class TransformerSemantics(ConfigEnum): diff --git a/metadata-ingestion/src/datahub/emitter/mcp_builder.py b/metadata-ingestion/src/datahub/emitter/mcp_builder.py index 5eb55be1a4ad80..882e5a89de8f13 100644 --- a/metadata-ingestion/src/datahub/emitter/mcp_builder.py +++ b/metadata-ingestion/src/datahub/emitter/mcp_builder.py @@ -51,7 +51,7 @@ class DatahubKey(BaseModel): def guid_dict(self) -> Dict[str, str]: - return self.dict(by_alias=True, exclude_none=True) + return self.model_dump(by_alias=True, exclude_none=True) def guid(self) -> str: bag = self.guid_dict() @@ -73,7 +73,7 @@ class ContainerKey(DatahubKey): backcompat_env_as_instance: bool = Field(default=False, exclude=True) def guid_dict(self) -> Dict[str, str]: - bag = self.dict(by_alias=True, exclude_none=True, exclude={"env"}) + bag = self.model_dump(by_alias=True, exclude_none=True, exclude={"env"}) if ( self.backcompat_env_as_instance @@ -85,7 +85,7 @@ def guid_dict(self) -> Dict[str, str]: return bag def property_dict(self) -> Dict[str, str]: - return self.dict(by_alias=True, exclude_none=True) + return self.model_dump(by_alias=True, exclude_none=True) def as_urn_typed(self) -> ContainerUrn: return ContainerUrn.from_string(self.as_urn()) diff --git a/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py b/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py index b6d68c347331ae..c25b2ccefc0d0c 100644 --- a/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py +++ b/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py @@ -3,7 +3,7 @@ from datahub_classify.helper_classes import ColumnInfo from datahub_classify.infotype_predictor import predict_infotypes from datahub_classify.reference_input import input1 as default_config -from pydantic import field_validator +from pydantic import ConfigDict, field_validator from pydantic.fields import Field from datahub.configuration.common import ConfigModel @@ -49,8 +49,7 @@ class ValuesFactorConfig(ConfigModel): class PredictionFactorsAndWeights(ConfigModel): - class Config: - populate_by_name = True + model_config = ConfigDict(populate_by_name=True) Name: float = Field(alias="name") Description: float = Field(alias="description") @@ -59,8 +58,7 @@ class Config: class InfoTypeConfig(ConfigModel): - class Config: - populate_by_name = True + model_config = ConfigDict(populate_by_name=True) Prediction_Factors_and_Weights: PredictionFactorsAndWeights = Field( description="Factors and their weights to consider when predicting info types", diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py index fd327af6d3488e..d0b5b34dab7be3 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/client.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py @@ -291,7 +291,7 @@ def _make_rest_sink_config( # TODO: We should refactor out the multithreading functionality of the sink # into a separate class that can be used by both the sink and the graph client # e.g. a DatahubBulkRestEmitter that both the sink and the graph client use. 
- return DatahubRestSinkConfig(**self.config.dict(), **(extra_config or {})) + return DatahubRestSinkConfig(**self.config.model_dump(), **(extra_config or {})) @contextlib.contextmanager def make_rest_sink( @@ -775,7 +775,7 @@ def set_connection_json( """ if isinstance(config, (ConfigModel, BaseModel)): - blob = config.json() + blob = config.model_dump_json() else: blob = json.dumps(config) diff --git a/metadata-ingestion/src/datahub/ingestion/graph/config.py b/metadata-ingestion/src/datahub/ingestion/graph/config.py index 02eee06d39668e..d0dec4fa3d828b 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/config.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/config.py @@ -1,6 +1,8 @@ from enum import Enum, auto from typing import Dict, List, Optional +from pydantic import ConfigDict + from datahub.configuration.common import ConfigModel from datahub.configuration.env_vars import get_datahub_component @@ -31,5 +33,4 @@ class DatahubClientConfig(ConfigModel): datahub_component: Optional[str] = None server_config_refresh_interval: Optional[int] = None - class Config: - extra = "ignore" + model_config = ConfigDict(extra="ignore") diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_lite.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_lite.py index 5f5c70259eff6c..87d72098ae3c60 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/datahub_lite.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_lite.py @@ -21,7 +21,7 @@ class DataHubLiteSinkConfig(LiteLocalConfig): class DataHubLiteSink(Sink[DataHubLiteSinkConfig, SinkReport]): def __post_init__(self) -> None: - self.datahub_lite = get_datahub_lite(self.config.dict()) + self.datahub_lite = get_datahub_lite(self.config.model_dump()) def write_record_async( self, diff --git a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py index 5809887e15041f..4cc505d3e31b9b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py +++ b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py @@ -7,7 +7,7 @@ import parse from cached_property import cached_property -from pydantic import field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from pydantic.fields import Field from wcmatch import pathlib @@ -83,8 +83,7 @@ class FolderTraversalMethod(Enum): class PathSpec(ConfigModel): - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) include: str = Field( description="Path to table. Name variable `{table}` is used to mark the folder with dataset. In absence of `{table}`, file level dataset will be created. 
Check below examples for more details.", diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py index 9d7d72640b183a..d2a794fb02e571 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py @@ -9,7 +9,7 @@ import dateutil.parser import requests from packaging import version -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from datahub.configuration.git import GitReference from datahub.configuration.validate_field_rename import pydantic_renamed_field @@ -341,8 +341,7 @@ class DBTRunTiming(BaseModel): class DBTRunResult(BaseModel): - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") status: str timing: List[DBTRunTiming] = [] diff --git a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/source.py b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/source.py index 1139663a429268..5d442057f07c29 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/source.py @@ -128,7 +128,7 @@ def __init__(self, config: DeltaLakeSourceConfig, ctx: PipelineContext): # self.profiling_times_taken = [] config_report = { - config_option: config.dict().get(config_option) + config_option: config.model_dump().get(config_option) for config_option in config_options_to_report } diff --git a/metadata-ingestion/src/datahub/ingestion/source/hex/api.py b/metadata-ingestion/src/datahub/ingestion/source/hex/api.py index 58ba64b75252a7..6b22163c4f803e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/hex/api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/hex/api.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Union import requests -from pydantic import BaseModel, Field, ValidationError, field_validator +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator from requests.adapters import HTTPAdapter from typing_extensions import assert_never from urllib3.util.retry import Retry @@ -133,8 +133,7 @@ class HexApiSharing(BaseModel): collections: Optional[List[HexApiCollectionAccess]] = [] groups: Optional[List[Any]] = [] - class Config: - extra = "ignore" # Allow extra fields in the JSON + model_config = ConfigDict(extra="ignore") # Allow extra fields in the JSON class HexApiItemType(StrEnum): @@ -165,8 +164,7 @@ class HexApiProjectApiResource(BaseModel): schedules: Optional[List[HexApiSchedule]] = [] sharing: Optional[HexApiSharing] = None - class Config: - extra = "ignore" # Allow extra fields in the JSON + model_config = ConfigDict(extra="ignore") # Allow extra fields in the JSON @field_validator( "created_at", @@ -200,8 +198,7 @@ class HexApiProjectsListResponse(BaseModel): values: List[HexApiProjectApiResource] pagination: Optional[HexApiPageCursors] = None - class Config: - extra = "ignore" # Allow extra fields in the JSON + model_config = ConfigDict(extra="ignore") # Allow extra fields in the JSON @dataclass diff --git a/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py index b9cac84942067b..c63ddac72624ba 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py @@ -114,7 +114,7 @@ 
class EntityNodeConfig(ConfigModel): # https://pydantic-docs.helpmanual.io/usage/postponed_annotations/ required for when you reference a model within itself -EntityNodeConfig.update_forward_refs() +EntityNodeConfig.model_rebuild() class LineageFileSourceConfig(ConfigModel): diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py index 9ab5f3a2ba41a0..518c317b535cdc 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from datahub.ingestion.source.powerbi_report_server.constants import ( RelationshipDirection, @@ -357,9 +357,8 @@ class OwnershipData(BaseModel): existing_owners: Optional[List[OwnerClass]] = [] owner_to_add: Optional[CorpUser] = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) -CatalogItem.update_forward_refs() -CorpUserProperties.update_forward_refs() +CatalogItem.model_rebuild() +CorpUserProperties.model_rebuild() diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py index 86d2d03af94660..f85a7e538db535 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple import pydantic +from pydantic import ConfigDict from datahub.configuration.time_window_config import BaseTimeWindowConfig from datahub.emitter.mce_builder import make_user_urn @@ -70,8 +71,7 @@ class PermissiveModel(pydantic.BaseModel): - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class SnowflakeColumnReference(PermissiveModel): diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py index d05b57cb7ac75c..e07c7c56a2598e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py @@ -340,7 +340,7 @@ def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str) self.classification_handler = ClassificationHandler(self.config, self.report) config_report = { - config_option: config.dict().get(config_option) + config_option: config.model_dump().get(config_option) for config_option in config_options_to_report } diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py index ef588e713f9951..1b776778746855 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py @@ -8,7 +8,7 @@ from typing import Any, ClassVar, Iterable, List, Optional, Union, cast import smart_open -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from datahub.configuration.common import HiddenFromDocs from 
datahub.configuration.datetimes import parse_user_datetime @@ -447,8 +447,7 @@ class QueryEntry(BaseModel): # Validation context for URN creation _validation_context: ClassVar[Optional[SqlQueriesSourceConfig]] = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) @field_validator("timestamp", mode="before") @classmethod diff --git a/metadata-ingestion/src/datahub/ingestion/source/superset.py b/metadata-ingestion/src/datahub/ingestion/source/superset.py index d7a90e7a645ca5..874b966c7a5d7a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/superset.py +++ b/metadata-ingestion/src/datahub/ingestion/source/superset.py @@ -9,7 +9,7 @@ import dateutil.parser as dp import requests import sqlglot -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, field_validator, model_validator from pydantic.fields import Field from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry @@ -242,9 +242,9 @@ class SupersetConfig( description="Can be used to change mapping for database names in superset to what you have in datahub", ) - class Config: - # This is required to allow preset configs to get parsed - extra = "allow" + model_config = ConfigDict( + extra="allow" # This is required to allow preset configs to get parsed + ) @field_validator("connect_uri", "display_uri", mode="after") @classmethod diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py index 0e406e0d061ee8..1b2919b728799c 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Type, cast +from pydantic import ConfigDict + from datahub.configuration.common import ( TransformerSemantics, TransformerSemanticsConfigModel, @@ -25,8 +27,7 @@ def get_properties_to_add(self, entity_urn: str) -> Dict[str, str]: class AddDatasetPropertiesConfig(TransformerSemanticsConfigModel): add_properties_resolver_class: Type[AddDatasetPropertiesResolverBase] - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) _resolve_properties_class = pydantic_resolve_key("add_properties_resolver_class") diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py index 108185bb2a25ce..dfcea7028d6b42 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py @@ -200,7 +200,7 @@ def __init__( domains = AddDatasetDomain.get_domain_class(ctx.graph, config.domains) generic_config = AddDatasetDomainSemanticsConfig( get_domains_to_add=lambda _: domains, - **config.dict(exclude={"domains"}), + **config.model_dump(exclude={"domains"}), ) super().__init__(generic_config, ctx) diff --git a/metadata-ingestion/src/datahub/secret/datahub_secret_store.py b/metadata-ingestion/src/datahub/secret/datahub_secret_store.py index 63a704a498d205..7d6bb479d5b258 100644 --- a/metadata-ingestion/src/datahub/secret/datahub_secret_store.py +++ b/metadata-ingestion/src/datahub/secret/datahub_secret_store.py @@ -1,7 +1,7 @@ import logging from typing import Any, Dict, List, Optional, Union 
-from pydantic import BaseModel, field_validator +from pydantic import BaseModel, ConfigDict, field_validator from datahub.ingestion.graph.client import DataHubGraph from datahub.ingestion.graph.config import DatahubClientConfig @@ -15,8 +15,7 @@ class DataHubSecretStoreConfig(BaseModel): graph_client: Optional[DataHubGraph] = None graph_client_config: Optional[DatahubClientConfig] = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) @field_validator("graph_client", mode="after") @classmethod diff --git a/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py b/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py index dfc1fd9c49c464..d1f67414be6d7d 100644 --- a/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py +++ b/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py @@ -57,7 +57,7 @@ def assert_sql_result_with_resolver( f"Missing expected golden file; run with --update-golden-files to create it: {expected_file}" ) - expected = SqlParsingResult.parse_raw(expected_file.read_text()) + expected = SqlParsingResult.model_validate_json(expected_file.read_text()) full_diff = deepdiff.DeepDiff( expected.model_dump(), diff --git a/metadata-ingestion/tests/unit/api/entities/datacontract/test_data_quality_assertion.py b/metadata-ingestion/tests/unit/api/entities/datacontract/test_data_quality_assertion.py index 7be8b667a500b3..dd465c10ad4706 100644 --- a/metadata-ingestion/tests/unit/api/entities/datacontract/test_data_quality_assertion.py +++ b/metadata-ingestion/tests/unit/api/entities/datacontract/test_data_quality_assertion.py @@ -26,7 +26,7 @@ def test_parse_sql_assertion(): "operator": {"type": "between", "min": 5, "max": 10}, } - assert DataQualityAssertion.parse_obj(d).generate_mcp( + assert DataQualityAssertion.model_validate(d).generate_mcp( assertion_urn, entity_urn ) == [ MetadataChangeProposalWrapper( diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_sql_common_state.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_sql_common_state.py index 51e978ad135a17..6cbbf8a84e393d 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_sql_common_state.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_sql_common_state.py @@ -87,5 +87,5 @@ def test_deduplication_and_order_preservation() -> None: ] # verifies that the state can be serialized without raising an error - json = state.json() + json = state.model_dump_json() assert json diff --git a/metadata-ingestion/tests/unit/test_athena_source.py b/metadata-ingestion/tests/unit/test_athena_source.py index 24a0f1aea1d9f4..4241f1a8ddc6a2 100644 --- a/metadata-ingestion/tests/unit/test_athena_source.py +++ b/metadata-ingestion/tests/unit/test_athena_source.py @@ -62,7 +62,7 @@ def test_athena_config_staging_dir_is_set_as_query_result(): } ) - assert config.json() == expected_config.json() + assert config.model_dump_json() == expected_config.model_dump_json() def test_athena_uri(): diff --git a/metadata-ingestion/tests/unit/test_mysql_rds_iam.py b/metadata-ingestion/tests/unit/test_mysql_rds_iam.py index c814b5ab3a2b91..0db5365220527e 100644 --- a/metadata-ingestion/tests/unit/test_mysql_rds_iam.py +++ b/metadata-ingestion/tests/unit/test_mysql_rds_iam.py @@ -14,7 +14,7 @@ def test_config_without_rds_iam(self): "password": "testpass", "database": "testdb", } - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) assert 
config.auth_mode == MySQLAuthMode.PASSWORD assert config.aws_config is not None # aws_config always has default value @@ -28,7 +28,7 @@ def test_config_with_rds_iam_valid(self): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) assert config.auth_mode == MySQLAuthMode.AWS_IAM assert config.aws_config is not None @@ -42,7 +42,7 @@ def test_config_with_rds_iam_without_explicit_aws_config(self): "database": "testdb", "auth_mode": "AWS_IAM", } - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) assert config.auth_mode == MySQLAuthMode.AWS_IAM assert config.aws_config is not None @@ -58,7 +58,7 @@ def test_init_without_rds_iam(self, mock_token_manager): "password": "testpass", "database": "testdb", } - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") source = MySQLSource(config, ctx) @@ -75,7 +75,7 @@ def test_init_with_rds_iam(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") source = MySQLSource(config, ctx) @@ -97,7 +97,7 @@ def test_init_with_rds_iam_custom_port(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") MySQLSource(config, ctx) @@ -117,7 +117,7 @@ def test_init_with_rds_iam_no_username(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") with pytest.raises(ValueError, match="username is required"): @@ -132,7 +132,7 @@ def test_init_with_rds_iam_invalid_port(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") MySQLSource(config, ctx) diff --git a/metadata-ingestion/tests/unit/test_postgres_rds_iam.py b/metadata-ingestion/tests/unit/test_postgres_rds_iam.py index 2ceee2a9d4802a..eca9a340303159 100644 --- a/metadata-ingestion/tests/unit/test_postgres_rds_iam.py +++ b/metadata-ingestion/tests/unit/test_postgres_rds_iam.py @@ -18,7 +18,7 @@ def test_config_without_rds_iam(self): "password": "testpass", "database": "testdb", } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) assert config.auth_mode == PostgresAuthMode.PASSWORD assert config.aws_config is not None # aws_config always has default value @@ -32,7 +32,7 @@ def test_config_with_rds_iam_valid(self): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) assert config.auth_mode == PostgresAuthMode.AWS_IAM assert config.aws_config is not None @@ -46,7 +46,7 @@ def test_config_with_rds_iam_without_explicit_aws_config(self): "database": "testdb", "auth_mode": "AWS_IAM", } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) assert config.auth_mode == 
PostgresAuthMode.AWS_IAM assert config.aws_config is not None @@ -62,7 +62,7 @@ def test_init_without_rds_iam(self, mock_token_manager): "password": "testpass", "database": "testdb", } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") source = PostgresSource(config, ctx) @@ -79,7 +79,7 @@ def test_init_with_rds_iam(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") source = PostgresSource(config, ctx) @@ -101,7 +101,7 @@ def test_init_with_rds_iam_custom_port(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") PostgresSource(config, ctx) @@ -121,7 +121,7 @@ def test_init_with_rds_iam_no_username(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") with pytest.raises(ValueError, match="username is required"): @@ -136,7 +136,7 @@ def test_init_with_rds_iam_invalid_port(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") PostgresSource(config, ctx) @@ -159,7 +159,7 @@ def test_init_with_rds_iam_stores_hostname_and_port(self, mock_token_manager): "auth_mode": "AWS_IAM", "aws_config": {"aws_region": "us-west-2"}, } - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test-run") source = PostgresSource(config, ctx) diff --git a/metadata-ingestion/tests/unit/test_unity_catalog_source.py b/metadata-ingestion/tests/unit/test_unity_catalog_source.py index dfe9df5d90a577..0b2b3d7b7d050e 100644 --- a/metadata-ingestion/tests/unit/test_unity_catalog_source.py +++ b/metadata-ingestion/tests/unit/test_unity_catalog_source.py @@ -25,7 +25,7 @@ def test_azure_auth_config_missing_fields(self, azure_auth_partial): "azure_auth": azure_auth_partial, } with pytest.raises(Exception) as exc_info: - UnityCatalogSourceConfig.parse_obj(config_dict) + UnityCatalogSourceConfig.model_validate(config_dict) # Should mention the missing field in the error message assert ( "client_id" in str(exc_info.value) @@ -76,7 +76,7 @@ def config_with_ml_model_settings(self): @pytest.fixture def config_with_azure_auth(self): """Create a config with Azure authentication.""" - return UnityCatalogSourceConfig.parse_obj( + return UnityCatalogSourceConfig.model_validate( { "workspace_url": "https://test.databricks.com", "warehouse_id": "test_warehouse", @@ -93,7 +93,7 @@ def config_with_azure_auth(self): @pytest.fixture def config_with_azure_auth_and_ml_models(self): """Create a config with Azure authentication and ML model settings.""" - return UnityCatalogSourceConfig.parse_obj( + return UnityCatalogSourceConfig.model_validate( { "workspace_url": "https://test.databricks.com", "warehouse_id": "test_warehouse", @@ -349,7 +349,7 @@ def test_azure_auth_config_validation(self): }, } - config = 
UnityCatalogSourceConfig.parse_obj(valid_config_dict) + config = UnityCatalogSourceConfig.model_validate(valid_config_dict) assert config.azure_auth is not None assert config.azure_auth.client_id == "test-client-id" assert config.azure_auth.tenant_id == "test-tenant-id" @@ -402,7 +402,7 @@ def test_source_creation_fails_without_authentication(self): """Test that UnityCatalogSource creation fails when neither token nor azure_auth are provided.""" # Test with neither token nor azure_auth provided - this should fail at config parsing with pytest.raises(ValueError) as exc_info: - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "workspace_url": "https://test.databricks.com", "warehouse_id": "test_warehouse", @@ -421,7 +421,7 @@ def test_source_creation_fails_without_authentication(self): def test_test_connection_fails_without_authentication(self): """Test that test_connection fails when neither token nor azure_auth are provided.""" with pytest.raises(ValueError) as exc_info: - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "workspace_url": "https://test.databricks.com", "warehouse_id": "test_warehouse", @@ -440,7 +440,7 @@ def test_source_creation_fails_with_both_authentication_methods(self): """Test that UnityCatalogSource creation fails when both token and azure_auth are provided.""" # Test with both token and azure_auth provided - this should fail at config parsing with pytest.raises(ValueError) as exc_info: - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "workspace_url": "https://test.databricks.com", "warehouse_id": "test_warehouse", @@ -587,7 +587,7 @@ def test_process_ml_model_version_with_run_details( Schema, ) - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", @@ -701,7 +701,7 @@ def test_process_ml_model_version_with_none_run_details( Schema, ) - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", @@ -796,7 +796,7 @@ def test_process_ml_model_version_with_partial_run_details( Schema, ) - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", diff --git a/smoke-test/tests/test_stateful_ingestion.py b/smoke-test/tests/test_stateful_ingestion.py index 41d3042a105229..0460d1168a5183 100644 --- a/smoke-test/tests/test_stateful_ingestion.py +++ b/smoke-test/tests/test_stateful_ingestion.py @@ -19,9 +19,9 @@ def test_stateful_ingestion(auth_session): def create_db_engine(sql_source_config_dict: Dict[str, Any]) -> Any: sql_config: Union[MySQLConfig, PostgresConfig] if get_db_type() == "mysql": - sql_config = MySQLConfig.parse_obj(sql_source_config_dict) + sql_config = MySQLConfig.model_validate(sql_source_config_dict) else: - sql_config = PostgresConfig.parse_obj(sql_source_config_dict) + sql_config = PostgresConfig.model_validate(sql_source_config_dict) url = sql_config.get_sql_alchemy_url() return create_engine(url) From 066fb30e925d8930f8ead7dd22031e1cace70874 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20G=C3=B3mez=20Villamor?= Date: Wed, 12 Nov 2025 08:45:52 +0000 Subject: [PATCH 4/4] fix smoke-test lint --- smoke-test/tests/cypress/test_weights.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/smoke-test/tests/cypress/test_weights.json b/smoke-test/tests/cypress/test_weights.json index d02f4ec2842d1b..f4f7571d589205 100644 --- a/smoke-test/tests/cypress/test_weights.json +++ b/smoke-test/tests/cypress/test_weights.json @@ -371,4 +371,4 @@ "filePath": "manage_tagsV2/search_bar_placeholder.js", "duration": "3.263s" } -] \ No newline at end of file +]
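
For reference, a minimal self-contained sketch (not part of the patch; the ExampleConfig model and its fields are hypothetical) of the Pydantic v2 idioms this migration standardizes on across the hunks above: ConfigDict assigned to model_config instead of a nested class Config, model_validate / model_validate_json instead of parse_obj / parse_raw, model_dump / model_dump_json instead of dict / json, and model_rebuild instead of update_forward_refs.

    # Illustrative only -- a sketch of the Pydantic v2 API surface used in this patch.
    from typing import Optional

    from pydantic import BaseModel, ConfigDict, Field, field_validator


    class ExampleConfig(BaseModel):
        # v2 replacement for `class Config: extra = "allow"` / `populate_by_name = True`
        model_config = ConfigDict(extra="allow", populate_by_name=True)

        server: str = Field(description="Base URL of the server")
        token: Optional[str] = None

        @field_validator("server", mode="after")
        @classmethod
        def strip_trailing_slash(cls, v: str) -> str:
            return v.rstrip("/")


    # model_validate replaces parse_obj; extra keys are tolerated because extra="allow"
    config = ExampleConfig.model_validate(
        {"server": "https://example.com/", "extra_key": 1}
    )

    # model_dump replaces dict(); model_dump_json replaces json()
    assert config.model_dump(exclude_none=True)["server"] == "https://example.com"
    blob = config.model_dump_json()

    # model_validate_json replaces parse_raw
    roundtrip = ExampleConfig.model_validate_json(blob)
    assert roundtrip.model_dump() == config.model_dump()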