diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 42d5117946..e386569102 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -102,7 +102,7 @@ def to_narwhals( if self._version is Version.V1: from narwhals.stable.v1 import DataFrame as DataFrameV1 - return DataFrameV1(self, level="interchange") # type: ignore[no-any-return] + return DataFrameV1(self, level="interchange") return self._version.lazyframe(self, level="lazy") def __narwhals_dataframe__(self) -> Self: # pragma: no cover @@ -116,7 +116,7 @@ def __narwhals_lazyframe__(self) -> Self: return self def __native_namespace__(self) -> ModuleType: - return get_duckdb() # type: ignore[no-any-return] + return get_duckdb() def __narwhals_namespace__(self) -> DuckDBNamespace: from narwhals._duckdb.namespace import DuckDBNamespace @@ -138,12 +138,8 @@ def collect( if backend is None or backend is Implementation.PYARROW: from narwhals._arrow.dataframe import ArrowDataFrame - if self._backend_version < (1, 4): - ret = self.native.arrow() - else: # pragma: no cover - ret = self.native.fetch_arrow_table() return ArrowDataFrame( - ret, + self.native.fetch_arrow_table(), validate_backend_version=True, version=self._version, validate_column_names=True, @@ -260,7 +256,7 @@ def to_pandas(self) -> pd.DataFrame: def to_arrow(self) -> pa.Table: # only if version is v1, keep around for backcompat - return self.lazy().collect(Implementation.PYARROW).native # type: ignore[no-any-return] + return self.lazy().collect(Implementation.PYARROW).native def _with_version(self, version: Version) -> Self: return self.__class__(self.native, version=version) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 418132b6a9..b3286850b9 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -256,7 +256,8 @@ def is_finite(self) -> Self: return self._with_elementwise(lambda expr: F("isfinite", expr)) def is_in(self, other: Sequence[Any]) -> Self: - return self._with_elementwise(lambda expr: F("contains", lit(other), expr)) + other_ = tuple(other) if not isinstance(other, (tuple, list)) else other + return self._with_elementwise(lambda expr: F("contains", lit(other_), expr)) def fill_null( self, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 09b5ecd8eb..827aab0936 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -37,6 +37,7 @@ from narwhals._utils import Version from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral +BIGINT = duckdb_dtypes.BIGINT VARCHAR = duckdb_dtypes.VARCHAR @@ -123,9 +124,7 @@ def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr: def func(cols: Iterable[Expression]) -> Expression: cols = tuple(cols) total = reduce(operator.add, (CoalesceOperator(col, lit(0)) for col in cols)) - count = reduce( - operator.add, (col.isnotnull().cast(duckdb_dtypes.BIGINT) for col in cols) - ) + count = reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols)) return total / count return self._expr._from_elementwise_horizontal_op(func, *exprs) diff --git a/narwhals/_duckdb/series.py b/narwhals/_duckdb/series.py index 5b284b95c3..91f3fc7ecb 100644 --- a/narwhals/_duckdb/series.py +++ b/narwhals/_duckdb/series.py @@ -24,7 +24,7 @@ def __narwhals_series__(self) -> Self: return self def __native_namespace__(self) -> ModuleType: - return get_duckdb() # type: ignore[no-any-return] + return get_duckdb() @property def dtype(self) -> DType: diff --git a/narwhals/_duckdb/typing.py b/narwhals/_duckdb/typing.py index cbd8d16847..43ca5a0509 100644 --- a/narwhals/_duckdb/typing.py +++ b/narwhals/_duckdb/typing.py @@ -1,11 +1,44 @@ from __future__ import annotations -from typing import TYPE_CHECKING, TypedDict +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, Union, overload + +import duckdb +from duckdb import Expression + +from narwhals._typing_compat import TypeVar if TYPE_CHECKING: - from collections.abc import Sequence + import uuid + + import numpy as np + import pandas as pd + from duckdb import DuckDBPyConnection + from typing_extensions import TypeAlias, TypeIs + + from narwhals.typing import Into1DArray, PythonLiteral - from duckdb import Expression + +__all__ = [ + "BaseType", + "IntoColumnExpr", + "WindowExpressionKwargs", + "has_children", + "is_dtype", +] + +IntoDuckDBLiteral: TypeAlias = """ + PythonLiteral + | dict[Any, Any] + | uuid.UUID + | bytearray + | memoryview + | Into1DArray + | pd.api.typing.NaTType + | pd.api.typing.NAType + | np.ma.MaskedArray + | duckdb.Value + """ class WindowExpressionKwargs(TypedDict, total=False): @@ -16,3 +49,102 @@ class WindowExpressionKwargs(TypedDict, total=False): descending: Sequence[bool] nulls_last: Sequence[bool] ignore_nulls: bool + + +_Children_co = TypeVar( + "_Children_co", + covariant=True, + bound=Sequence[tuple[str, Any]], + default=Sequence[tuple[str, Any]], +) +DTypeT_co = TypeVar("DTypeT_co", covariant=True, bound="BaseType", default="BaseType") +_Child: TypeAlias = tuple[Literal["child"], DTypeT_co] +_Size: TypeAlias = tuple[Literal["size"], int] +_ID_co = TypeVar("_ID_co", bound=str, default=str, covariant=True) +_Array: TypeAlias = Literal["array"] +_Struct: TypeAlias = Literal["struct"] +_List: TypeAlias = Literal["list"] +_Enum: TypeAlias = Literal["enum"] +_Decimal: TypeAlias = Literal["decimal"] +_TimestampTZ: TypeAlias = Literal["timestamp with time zone"] +IntoColumnExpr: TypeAlias = Union[str, Expression] +"""A column name, or the result of calling `duckdb.ColumnExpression`.""" + + +class BaseType(Protocol[_ID_co]): + """Structural equivalent to [`DuckDBPyType`]. + + Excludes attributes which are unsafe to use on most types. + + [`DuckDBPyType`]: https://github.com/duckdb/duckdb-python/blob/df7789cbd31b2d2b8d03d012f14331bc3297fb2d/_duckdb-stubs/_sqltypes.pyi#L35-L75 + """ + + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + @overload + def __init__(self, type_str: str, connection: DuckDBPyConnection) -> None: ... + @overload + def __init__(self, obj: object) -> None: ... + @property + def id(self) -> _ID_co: ... + + +def has_children( + dtype: BaseType | _ParentType[_ID_co, _Children_co], +) -> TypeIs[_ParentType[_ID_co, _Children_co]]: + """Return True if `dtype.children` can be accessed safely. + + `_hasattr_static` returns True on *any* [`DuckDBPyType`], so the only way to be sure is by forcing an exception. + + [`DuckDBPyType`]: https://github.com/duckdb/duckdb-python/blob/df7789cbd31b2d2b8d03d012f14331bc3297fb2d/_duckdb-stubs/_sqltypes.pyi#L35-L75 + """ + try: + return hasattr(dtype, "children") + except duckdb.InvalidInputException: + return False + + +@overload +def is_dtype(obj: BaseType, type_id: _Array, /) -> TypeIs[ArrayType]: ... +@overload +def is_dtype(obj: BaseType, type_id: _Struct, /) -> TypeIs[StructType]: ... +@overload +def is_dtype(obj: BaseType, type_id: _List, /) -> TypeIs[ListType]: ... +@overload +def is_dtype(obj: BaseType, type_id: _Enum, /) -> TypeIs[EnumType]: ... +@overload +def is_dtype(obj: BaseType, type_id: _Decimal, /) -> TypeIs[DecimalType]: ... +@overload +def is_dtype( + obj: BaseType, type_id: _TimestampTZ, / +) -> TypeIs[BaseType[_TimestampTZ]]: ... +def is_dtype( + obj: BaseType, type_id: _Array | _Struct | _List | _Enum | _Decimal | _TimestampTZ, / +) -> bool: + """Return True if `obj` is the [`DuckDBPyType`] corresponding with `type_id`. + + [`DuckDBPyType`]: https://github.com/duckdb/duckdb-python/blob/df7789cbd31b2d2b8d03d012f14331bc3297fb2d/_duckdb-stubs/_sqltypes.pyi#L35-L75 + """ + return obj.id == type_id + + +class _ParentType(BaseType[_ID_co], Protocol[_ID_co, _Children_co]): + @property + def children(self) -> _Children_co: ... + + +ArrayType: TypeAlias = _ParentType[_Array, tuple[_Child[DTypeT_co], _Size]] +EnumType: TypeAlias = _ParentType[_Enum, tuple[tuple[Literal["values"], list[str]]]] +DecimalType: TypeAlias = _ParentType[ + _Decimal, tuple[tuple[Literal["precision"], int], tuple[Literal["scale"], int]] +] + + +class ListType(_ParentType[_List, tuple[_Child[DTypeT_co]]], Protocol[DTypeT_co]): + @property + def child(self) -> DTypeT_co: ... + + +class StructType(_ParentType[_Struct, Sequence[tuple[str, BaseType]]], Protocol): + def __getattr__(self, name: str) -> BaseType: ... + def __getitem__(self, name: str) -> BaseType: ... diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 6304ba86d7..10141aafa8 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -1,31 +1,50 @@ from __future__ import annotations from functools import lru_cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import duckdb from duckdb import Expression -try: - import duckdb.sqltypes as duckdb_dtypes -except ModuleNotFoundError: - # DuckDB pre 1.3 - import duckdb.typing as duckdb_dtypes - -from narwhals._utils import Version, extend_bool, isinstance_or_issubclass, zip_strict +from narwhals._duckdb.typing import ( + BaseType, + IntoColumnExpr, + IntoDuckDBLiteral, + has_children, + is_dtype, +) +from narwhals._utils import ( + Implementation, + Version, + extend_bool, + isinstance_or_issubclass, + zip_strict, +) from narwhals.exceptions import ColumnNotFoundError if TYPE_CHECKING: from collections.abc import Mapping, Sequence from duckdb import DuckDBPyRelation + from duckdb.sqltypes import DuckDBPyType + from typing_extensions import TypeAlias + import narwhals._duckdb.typing from narwhals._compliant.typing import CompliantLazyFrameAny from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.expr import DuckDBExpr from narwhals.dtypes import DType from narwhals.typing import IntoDType, TimeUnit +Incomplete: TypeAlias = Any + +BACKEND_VERSION = Implementation.DUCKDB._backend_version() +"""Static backend version for `duckdb`.""" + +if TYPE_CHECKING or BACKEND_VERSION >= (1, 4): + from duckdb import sqltypes as duckdb_dtypes +else: # pragma: no cover + from duckdb import typing as duckdb_dtypes UNITS_DICT = { "y": "year", @@ -45,8 +64,14 @@ col = duckdb.ColumnExpression """Alias for `duckdb.ColumnExpression`.""" -lit = duckdb.ConstantExpression -"""Alias for `duckdb.ConstantExpression`.""" + +# TODO @dangotbanned: Raise an issue upstream on `Expression | str` too narrow +# NOTE: https://github.com/duckdb/duckdb-python/blob/df7789cbd31b2d2b8d03d012f14331bc3297fb2d/src/duckdb_py/native/python_conversion.cpp#L916-L1069 +def lit(value: IntoDuckDBLiteral | Expression) -> Expression: + """Alias for `duckdb.ConstantExpression`.""" + lit_: Incomplete = duckdb.ConstantExpression + return lit_(value) + when = duckdb.CaseExpression """Alias for `duckdb.CaseExpression`.""" @@ -55,8 +80,10 @@ """Alias for `duckdb.FunctionExpression`.""" +# TODO @dangotbanned: Raise an issue upstream on `Expression | str | tuple[str` too narrow +# NOTE: https://github.com/duckdb/duckdb-python/blob/df7789cbd31b2d2b8d03d012f14331bc3297fb2d/src/duckdb_py/pyexpression.cpp#L361-L413 def lambda_expr( - params: str | Expression | tuple[Expression, ...], expr: Expression, / + params: IntoColumnExpr | tuple[IntoColumnExpr, ...], expr: Expression, / ) -> Expression: """Wraps [`duckdb.LambdaExpression`]. @@ -68,7 +95,8 @@ def lambda_expr( msg = f"DuckDB>=1.2.0 is required for this operation. Found: DuckDB {duckdb.__version__}" raise NotImplementedError(msg) from exc args = (params,) if isinstance(params, Expression) else params - return LambdaExpression(args, expr) + lambda_expr_: Incomplete = LambdaExpression + return lambda_expr_(args, expr) def concat_str(*exprs: Expression, separator: str = "") -> Expression: @@ -135,20 +163,27 @@ def time_zone(self) -> str: def native_to_narwhals_dtype( - duckdb_dtype: duckdb_dtypes.DuckDBPyType, + duckdb_dtype: BaseType, version: Version, deferred_time_zone: DeferredTimeZone +) -> DType: + if has_children(duckdb_dtype) and not is_dtype(duckdb_dtype, "decimal"): + return _nested_native_to_narwhals_dtype(duckdb_dtype, version, deferred_time_zone) + if is_dtype(duckdb_dtype, "timestamp with time zone"): + return version.dtypes.Datetime(time_zone=deferred_time_zone.time_zone) + return _non_nested_native_to_narwhals_dtype(duckdb_dtype.id, version) + + +def _nested_native_to_narwhals_dtype( + duckdb_dtype: narwhals._duckdb.typing._ParentType, version: Version, deferred_time_zone: DeferredTimeZone, ) -> DType: - duckdb_dtype_id = duckdb_dtype.id dtypes = version.dtypes - # Handle nested data types first - if duckdb_dtype_id == "list": + if is_dtype(duckdb_dtype, "list"): return dtypes.List( native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone) ) - - if duckdb_dtype_id == "struct": + if is_dtype(duckdb_dtype, "struct"): children = duckdb_dtype.children return dtypes.Struct( [ @@ -159,28 +194,24 @@ def native_to_narwhals_dtype( for child in children ] ) - - if duckdb_dtype_id == "array": + if is_dtype(duckdb_dtype, "array"): child, size = duckdb_dtype.children shape: list[int] = [size[1]] - while child[1].id == "array": + while is_dtype(child[1], "array"): child, size = child[1].children shape.insert(0, size[1]) inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone) return dtypes.Array(inner=inner, shape=tuple(shape)) - - if duckdb_dtype_id == "enum": + if is_dtype(duckdb_dtype, "enum"): if version is Version.V1: - return dtypes.Enum() # type: ignore[call-arg] + return dtypes.Enum() # pyright: ignore[reportCallIssue] categories = duckdb_dtype.children[0][1] return dtypes.Enum(categories=categories) - - if duckdb_dtype_id == "timestamp with time zone": - return dtypes.Datetime(time_zone=deferred_time_zone.time_zone) - - return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version) + # TODO @dangotbanned: Get coverage during https://github.com/narwhals-dev/narwhals/issues/3197 + # `MAP`, `UNION` + return dtypes.Unknown() # pragma: no cover def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str: @@ -188,7 +219,7 @@ def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str: "duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'" ).fetchone() assert result is not None # noqa: S101 - return result[0] # type: ignore[no-any-return] + return result[0] @lru_cache(maxsize=16) @@ -222,7 +253,7 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) dtypes = Version.MAIN.dtypes -NW_TO_DUCKDB_DTYPES: Mapping[type[DType], duckdb_dtypes.DuckDBPyType] = { +NW_TO_DUCKDB_DTYPES: Mapping[type[DType], DuckDBPyType] = { dtypes.Float64: duckdb_dtypes.DOUBLE, dtypes.Float32: duckdb_dtypes.FLOAT, dtypes.Binary: duckdb_dtypes.BLOB, @@ -241,7 +272,7 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) dtypes.UInt64: duckdb_dtypes.UBIGINT, dtypes.UInt128: duckdb_dtypes.UHUGEINT, } -TIME_UNIT_TO_TIMESTAMP: Mapping[TimeUnit, duckdb_dtypes.DuckDBPyType] = { +TIME_UNIT_TO_TIMESTAMP: Mapping[TimeUnit, DuckDBPyType] = { "s": duckdb_dtypes.TIMESTAMP_S, "ms": duckdb_dtypes.TIMESTAMP_MS, "us": duckdb_dtypes.TIMESTAMP, @@ -252,7 +283,7 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) def narwhals_to_native_dtype( # noqa: PLR0912, C901 dtype: IntoDType, version: Version, deferred_time_zone: DeferredTimeZone -) -> duckdb_dtypes.DuckDBPyType: +) -> DuckDBPyType: dtypes = version.dtypes base_type = dtype.base_type() if duckdb_type := NW_TO_DUCKDB_DTYPES.get(base_type): diff --git a/pyproject.toml b/pyproject.toml index 002e4bd024..cd09557aaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ extra = [ # heavier dependencies we don't necessarily need in every testing job "scikit-learn", ] typing = [ # keep some of these pinned and bump periodically so there's fewer surprises for contributors - "duckdb==1.3.0", + "duckdb==1.4.1", "hypothesis", "pytest", "pandas-stubs==2.3.0.250703", @@ -71,10 +71,11 @@ typing = [ # keep some of these pinned and bump periodically so there's fewer s "mypy~=1.15.0", "pyright", "pyarrow-stubs==19.2", - "sqlframe", + "sqlframe>=3.43.5", "polars==1.34.0", "uv", "narwhals[ibis]", + "ibis-framework==11.0.0", ] typing-ci = [ "narwhals[dask,modin]", @@ -322,7 +323,6 @@ module = [ "cupy.*", "dask.*", "dask_expr.*", - "duckdb.*", # https://github.com/ibis-project/ibis/issues/6844 "ibis.*", "joblib.*", @@ -344,6 +344,7 @@ module = [ "narwhals._arrow.*", "narwhals._dask.*", "narwhals._spark_like.*", + "narwhals._duckdb.*" ] warn_return_any = false diff --git a/tests/expr_and_series/is_in_test.py b/tests/expr_and_series/is_in_test.py index 2ae6cabea5..e91ff2e8c8 100644 --- a/tests/expr_and_series/is_in_test.py +++ b/tests/expr_and_series/is_in_test.py @@ -26,6 +26,18 @@ def test_expr_is_in_empty_list(constructor: Constructor) -> None: assert_equal_data(result, expected) +def test_expr_is_in_iterable( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if any(x in str(constructor) for x in ("sqlframe", "polars")): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) + sequence = 4, 2 + result = df.select(nw.col("a").is_in(iter(sequence))) + expected = {"a": [False, True, True, False]} + assert_equal_data(result, expected) + + def test_ser_is_in(constructor_eager: ConstructorEager) -> None: ser = nw.from_native(constructor_eager(data), eager_only=True)["a"] result = {"a": ser.is_in([4, 5])}