Skip to content

Commit 786bd77

Browse files
committed
feat: Support int_range(eager=...)
May need tweaking, pending the outcome of the discussion in #2895 (comment).
1 parent ad32568 commit 786bd77

File tree

6 files changed

+164
-19
lines changed

6 files changed

+164
-19
lines changed

narwhals/_plan/arrow/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from __future__ import annotations
2+
3+
from narwhals._plan.arrow.dataframe import ArrowDataFrame as DataFrame
4+
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
5+
from narwhals._plan.arrow.namespace import ArrowNamespace as Namespace
6+
from narwhals._plan.arrow.series import ArrowSeries as Series
7+
8+
__all__ = ["DataFrame", "Expr", "Namespace", "Scalar", "Series"]

narwhals/_plan/arrow/namespace.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from narwhals._plan.compliant.namespace import EagerNamespace
1414
from narwhals._plan.expressions.literal import is_literal_scalar
1515
from narwhals._typing_compat import TypeVar
16-
from narwhals._utils import Version
16+
from narwhals._utils import Implementation, Version
1717
from narwhals.exceptions import InvalidOperationError
1818

1919
if TYPE_CHECKING:
@@ -23,19 +23,24 @@
2323
from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame
2424
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
2525
from narwhals._plan.arrow.series import ArrowSeries as Series
26+
from narwhals._plan.arrow.typing import ChunkedArray, IntegerScalar
2627
from narwhals._plan.expressions import expr, functions as F
2728
from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal
2829
from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr
2930
from narwhals._plan.expressions.ranges import DateRange, IntRange
3031
from narwhals._plan.expressions.strings import ConcatStr
3132
from narwhals._plan.series import Series as NwSeries
33+
from narwhals.dtypes import IntegerType
3234
from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral
3335

3436

3537
PythonLiteralT = TypeVar("PythonLiteralT", bound="PythonLiteral")
38+
Int64 = Version.MAIN.dtypes.Int64()
3639

3740

3841
class ArrowNamespace(EagerNamespace["Frame", "Series", "Expr", "Scalar"]):
42+
implementation = Implementation.PYARROW
43+
3944
def __init__(self, version: Version = Version.MAIN) -> None:
4045
self._version = version
4146

@@ -186,14 +191,33 @@ def _range_function_inputs(
186191
msg = f"All inputs for `{node.function}()` resolve to {valid_type.__name__}, but got \n{start_!r}\n{end_!r}"
187192
raise InvalidOperationError(msg)
188193

194+
def _int_range(
    self, start: int, end: int, step: int, dtype: IntegerType, /
) -> ChunkedArray[IntegerScalar]:
    """Build the native integer range for the requested narwhals dtype.

    Arguments:
        start: Forwarded to `fn.int_range` as the first bound.
        end: Forwarded to `fn.int_range` as the second bound.
        step: Forwarded to `fn.int_range` as the step.
        dtype: Narwhals integer dtype requested for the result.

    Returns:
        The native array produced by `fn.int_range`.

    Raises:
        TypeError: If `dtype` converts to a pyarrow type that is not an integer.
    """
    # Compare by equality, not identity: `Int64` is a module-level *instance*,
    # so a freshly constructed `Int64()` would never satisfy an `is` check and
    # would always take the conversion path below.
    if dtype == Int64:
        # No dtype is passed here, as in the original fast path
        # (presumably `fn.int_range` defaults to int64 - confirm in `fn`).
        return fn.int_range(start, end, step)
    pa_dtype = narwhals_to_native_dtype(dtype, self.version)
    if not pa.types.is_integer(pa_dtype):
        raise TypeError(dtype)
    return fn.int_range(start, end, step, dtype=pa_dtype)
203+
189204
def int_range(self, node: RangeExpr[IntRange], frame: Frame, name: str) -> Expr:
    """Evaluate an `IntRange` expression node against `frame`.

    The bounds are resolved through `_range_function_inputs` (requesting `int`),
    the native range is built by `_int_range`, and the result is wrapped as an
    expression named `name`.
    """
    lower, upper = self._range_function_inputs(node, frame, int)
    function = node.function
    native = self._int_range(lower, upper, function.step, function.dtype)
    return self._expr.from_native(native, name, self.version)
196208

209+
def int_range_eager(
    self,
    start: int,
    end: int,
    step: int = 1,
    *,
    dtype: IntegerType = Int64,
    name: str = "literal",
) -> Series:
    """Build an integer range directly as a series, without an expression context.

    Arguments:
        start: Forwarded to `_int_range`.
        end: Forwarded to `_int_range`.
        step: Forwarded to `_int_range`.
        dtype: Narwhals integer dtype of the result.
        name: Name given to the resulting series.
    """
    return self._series.from_native(
        self._int_range(start, end, step, dtype), name, version=self.version
    )
220+
197221
def date_range(self, node: RangeExpr[DateRange], frame: Frame, name: str) -> Expr:
198222
start, end = self._range_function_inputs(node, frame, dt.date)
199223
func = node.function

narwhals/_plan/compliant/namespace.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Literal, Protocol, overload
3+
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload
44

55
from narwhals._plan.compliant.typing import (
66
ConcatT1,
@@ -16,6 +16,7 @@
1616
ScalarT_co,
1717
SeriesT,
1818
)
19+
from narwhals._utils import Implementation, Version
1920

2021
if TYPE_CHECKING:
2122
from collections.abc import Iterable
@@ -27,10 +28,15 @@
2728
from narwhals._plan.expressions.ranges import DateRange, IntRange
2829
from narwhals._plan.expressions.strings import ConcatStr
2930
from narwhals._plan.series import Series
31+
from narwhals.dtypes import IntegerType
3032
from narwhals.typing import ConcatMethod, NonNestedLiteral
3133

34+
Int64 = Version.MAIN.dtypes.Int64()
35+
3236

3337
class CompliantNamespace(HasVersion, Protocol[FrameT, ExprT_co, ScalarT_co]):
38+
implementation: ClassVar[Implementation]
39+
3440
@property
3541
def _expr(self) -> type[ExprT_co]: ...
3642
@property
@@ -131,6 +137,15 @@ def lit(
131137
def lit(
132138
self, node: ir.Literal[Any], frame: EagerDataFrameT, name: str
133139
) -> EagerExprT_co | EagerScalarT_co: ...
140+
def int_range_eager(
    self,
    start: int,
    end: int,
    step: int = 1,
    *,
    dtype: IntegerType = Int64,
    name: str = "literal",
) -> SeriesT:
    """Return an integer range as a compliant series.

    Protocol member: implementations take plain integer bounds and a step,
    plus keyword-only `dtype` and `name`, and return the backend's series type.
    """
    ...
134149

135150

136151
class LazyNamespace(

narwhals/_plan/functions.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,27 @@
1212
from narwhals._plan.expressions.ranges import DateRange, IntRange
1313
from narwhals._plan.expressions.strings import ConcatStr
1414
from narwhals._plan.when_then import When
15-
from narwhals._utils import Version, flatten
16-
from narwhals.exceptions import ComputeError
15+
from narwhals._utils import Implementation, Version, flatten, is_eager_allowed
16+
from narwhals.exceptions import ComputeError, InvalidOperationError
1717

1818
if TYPE_CHECKING:
19+
import pyarrow as pa
20+
21+
from narwhals._plan import arrow as _arrow
22+
from narwhals._plan.compliant.namespace import EagerNamespace
23+
from narwhals._plan.compliant.series import CompliantSeries
1924
from narwhals._plan.expr import Expr
2025
from narwhals._plan.series import Series
2126
from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT
27+
from narwhals._typing import Arrow
2228
from narwhals.dtypes import IntegerType
23-
from narwhals.typing import ClosedInterval, IntoDType, NonNestedLiteral
29+
from narwhals.typing import (
30+
ClosedInterval,
31+
EagerAllowed,
32+
IntoBackend,
33+
IntoDType,
34+
NonNestedLiteral,
35+
)
2436

2537

2638
def col(*names: str | t.Iterable[str]) -> Expr:
@@ -145,27 +157,94 @@ def when(
145157
return When._from_ir(condition)
146158

147159

160+
@t.overload
def int_range(
    start: int | IntoExprColumn = ...,
    end: int | IntoExprColumn | None = ...,
    step: int = ...,
    *,
    dtype: IntegerType | type[IntegerType] = ...,
    eager: t.Literal[False] = ...,
) -> Expr: ...
@t.overload
def int_range(
    start: int = ...,
    end: int | None = ...,
    step: int = ...,
    *,
    dtype: IntegerType | type[IntegerType] = ...,
    eager: Arrow,
) -> Series[pa.ChunkedArray[t.Any]]: ...
@t.overload
def int_range(
    start: int = ...,
    end: int | None = ...,
    step: int = ...,
    *,
    dtype: IntegerType | type[IntegerType] = ...,
    eager: IntoBackend[EagerAllowed],
) -> Series: ...
def int_range(
    start: int | IntoExprColumn = 0,
    end: int | IntoExprColumn | None = None,
    step: int = 1,
    *,
    dtype: IntegerType | type[IntegerType] = Version.MAIN.dtypes.Int64,
    eager: IntoBackend[EagerAllowed] | t.Literal[False] = False,
) -> Expr | Series:
    """Create a range of integers, lazily as an `Expr` or eagerly as a `Series`.

    Arguments:
        start: Start of the range; when `end` is omitted this value is used as
            the end and the range starts at 0.
        end: End of the range.
        step: Step between values.
        dtype: Integer dtype (class or instance) of the result.
        eager: `False` to get an expression, or an eager backend to get a
            series built by that backend's namespace.
    """
    # A single positional bound means "0..start".
    if end is None:
        start, end = 0, start
    resolved = common.into_dtype(dtype)
    if not eager:
        range_ir = IntRange(step=step, dtype=resolved)
        bounds = _parse.parse_into_seq_of_expr_ir(start, end)
        return range_ir.to_function_expr(*bounds).to_narwhals()
    # Resolve the backend first so unsupported backends fail before the
    # bounds are validated, as in the original call order.
    namespace = _eager_namespace(eager)
    return _int_range_eager(start, end, step, dtype=resolved, ns=namespace)
167206

168207

208+
def _int_range_eager(
    start: t.Any,
    end: t.Any,
    step: int,
    *,
    dtype: IntegerType,
    ns: EagerNamespace[t.Any, CompliantSeries[NativeSeriesT], t.Any, t.Any],
) -> Series[NativeSeriesT]:
    """Validate the bounds and delegate eager range creation to `ns`.

    Arguments:
        start: Range start; must be an `int`.
        end: Range end; must be an `int`.
        step: Step between values.
        dtype: Integer dtype of the result.
        ns: Eager namespace whose `int_range_eager` builds the series.

    Returns:
        The narwhals-level series obtained via `.to_narwhals()`.

    Raises:
        InvalidOperationError: If `start` or `end` is not an `int`
            (e.g. an expression was passed together with an eager backend).
    """
    if not (isinstance(start, int) and isinstance(end, int)):
        # Each hint bullet ends with an explicit newline; adjacent string
        # literals are concatenated, so without it the bullets share one line.
        msg = (
            f"Expected `start` and `end` to be integer values since `eager={ns.implementation}`.\n"
            f"Found: `start` of type {type(start)} and `end` of type {type(end)}\n\n"
            "Hint: Calling `nw.int_range` with expressions requires:\n"
            " - `eager=False`\n"
            " - a context such as `select` or `with_columns`"
        )
        raise InvalidOperationError(msg)
    return ns.int_range_eager(start, end, step, dtype=dtype).to_narwhals()
226+
227+
228+
@t.overload
def _eager_namespace(backend: Arrow, /) -> _arrow.Namespace: ...
@t.overload
def _eager_namespace(
    backend: IntoBackend[EagerAllowed], /
) -> EagerNamespace[t.Any, t.Any, t.Any, t.Any]: ...
def _eager_namespace(
    backend: IntoBackend[EagerAllowed], /
) -> EagerNamespace[t.Any, t.Any, t.Any, t.Any] | _arrow.Namespace:
    """Resolve `backend` to the eager namespace that implements it.

    Raises:
        ValueError: If the resolved implementation is not eager-allowed.
        NotImplementedError: If the implementation is eager-allowed but is
            not pyarrow, the only backend handled here.
    """
    impl = Implementation.from_backend(backend)
    if not is_eager_allowed(impl):
        msg = f"{impl} support in Narwhals is lazy-only"
        raise ValueError(msg)
    if impl is not Implementation.PYARROW:
        raise NotImplementedError(impl)
    # Imported lazily, at the point of use.
    from narwhals._plan.arrow.namespace import ArrowNamespace

    return ArrowNamespace(Version.MAIN)
246+
247+
169248
def date_range(
170249
start: dt.date | IntoExprColumn,
171250
end: dt.date | IntoExprColumn,

tests/plan/compliant_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,3 +658,11 @@ def test_dataframe_from_native_overloads() -> None:
658658
native_bad = native_good.to_batches()[0]
659659
nwp.DataFrame.from_native(native_bad) # type: ignore[call-overload]
660660
assert_type(native_bad, "pa.RecordBatch")
661+
662+
def test_int_range_overloads() -> None:
    """Check the static types through an eager pyarrow `int_range` round trip."""
    ser = nwp.int_range(50, eager="pyarrow")
    assert_type(ser, "nwp.Series[pa.ChunkedArray[Any]]")
    chunked = ser.to_native()
    assert_type(chunked, "pa.ChunkedArray[Any]")
    restored = nwp.Series.from_native(chunked)
    assert_type(restored, "nwp.Series[pa.ChunkedArray[Any]]")

tests/plan/expr_parsing_test.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from narwhals._plan.expressions import functions as F, operators as ops
1717
from narwhals._plan.expressions.literal import SeriesLiteral
1818
from narwhals._plan.expressions.ranges import IntRange
19+
from narwhals._utils import Implementation
1920
from narwhals.exceptions import (
2021
ComputeError,
2122
InvalidIntoExprError,
@@ -186,11 +187,21 @@ def test_int_range_invalid() -> None:
186187
int_range.to_function_expr(ir.col("a"))
187188

188189

189-
@pytest.mark.xfail(
190-
reason="Not implemented `int_range(eager=True)`", raises=NotImplementedError
191-
)
192-
def test_int_range_series() -> None:
193-
assert isinstance(nwp.int_range(50, eager=True), nwp.Series)
190+
def test_int_range_eager() -> None:
    """Eager `int_range` returns a Series for pyarrow and rejects bad inputs/backends."""
    expected = list(range(50))

    from_name = nwp.int_range(50, eager="pyarrow")
    assert isinstance(from_name, nwp.Series)
    assert from_name.to_list() == expected
    from_impl = nwp.int_range(50, eager=Implementation.PYARROW)
    assert from_impl.to_list() == expected

    # Expression bounds are not allowed together with an eager backend.
    with pytest.raises(InvalidOperationError):
        nwp.int_range(nwp.len(), eager="pyarrow")  # type: ignore[call-overload]
    with pytest.raises(InvalidOperationError):
        nwp.int_range(10, nwp.col("a").last(), eager=Implementation.PYARROW)  # type: ignore[call-overload]
    # Eager-allowed backend without an implementation yet.
    with pytest.raises(NotImplementedError):
        nwp.int_range(10, eager="pandas")
    # Lazy-only backend.
    with pytest.raises(ValueError, match=r"lazy-only"):
        nwp.int_range(10, eager="duckdb")  # type: ignore[call-overload]
194205

195206

196207
def test_over_invalid() -> None:

0 commit comments

Comments
 (0)