Skip to content

Commit ebb2a40

Browse files
fix: BaseFrame.filter with list[bool] in predicates (#3183)
Co-authored-by: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
1 parent f9a4617 commit ebb2a40

File tree

5 files changed

+201
-70
lines changed

5 files changed

+201
-70
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from narwhals._arrow.namespace import ArrowNamespace
4242
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
4343
ChunkedArrayAny,
44-
Mask,
4544
Order,
4645
)
4746
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
@@ -518,12 +517,9 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
518517
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
519518
return self.select(row_index, plx.all())
520519

521-
def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
522-
if isinstance(predicate, list):
523-
mask_native: Mask | ChunkedArrayAny = predicate
524-
else:
525-
# `[0]` is safe as the predicate's expression only returns a single column
526-
mask_native = self._evaluate_into_exprs(predicate)[0].native
520+
def filter(self, predicate: ArrowExpr) -> Self:
521+
# `[0]` is safe as the predicate's expression only returns a single column
522+
mask_native = self._evaluate_into_exprs(predicate)[0].native
527523
return self._with_native(
528524
self.native.filter(mask_native), validate_column_names=False
529525
)

narwhals/_pandas_like/dataframe.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -461,13 +461,10 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
461461
def row(self, index: int) -> tuple[Any, ...]:
462462
return tuple(x for x in self.native.iloc[index])
463463

464-
def filter(self, predicate: PandasLikeExpr | list[bool]) -> Self:
465-
if isinstance(predicate, list):
466-
mask_native: pd.Series[Any] | list[bool] = predicate
467-
else:
468-
# `[0]` is safe as the predicate's expression only returns a single column
469-
mask = self._evaluate_into_exprs(predicate)[0]
470-
mask_native = self._extract_comparand(mask)
464+
def filter(self, predicate: PandasLikeExpr) -> Self:
465+
# `[0]` is safe as the predicate's expression only returns a single column
466+
mask = self._evaluate_into_exprs(predicate)[0]
467+
mask_native = self._extract_comparand(mask)
471468
return self._with_native(
472469
self.native.loc[mask_native], validate_column_names=False
473470
)

narwhals/_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import re
5+
import sys
56
from collections.abc import Collection, Container, Iterable, Iterator, Mapping, Sequence
67
from datetime import timezone
78
from enum import Enum, auto
@@ -690,6 +691,10 @@ def _is_iterable(arg: Any | Iterable[Any]) -> bool:
690691
return isinstance(arg, Iterable) and not isinstance(arg, (str, bytes, Series))
691692

692693

694+
def is_iterator(val: Iterable[_T] | Any) -> TypeIs[Iterator[_T]]:
695+
return isinstance(val, Iterator)
696+
697+
693698
def parse_version(version: str | ModuleType | _SupportsVersion) -> tuple[int, ...]:
694699
"""Simple version parser; split into a tuple of ints for comparison.
695700
@@ -1344,6 +1349,12 @@ def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[_T]]:
13441349
return bool(isinstance(obj, list) and obj and isinstance(obj[0], tp))
13451350

13461351

1352+
def predicates_contains_list_of_bool(
1353+
predicates: Collection[Any],
1354+
) -> TypeIs[Collection[list[bool]]]:
1355+
return any(is_list_of(pred, bool) for pred in predicates)
1356+
1357+
13471358
def is_sequence_of(obj: Any, tp: type[_T]) -> TypeIs[Sequence[_T]]:
13481359
# Check if an object is a sequence of `tp`, only sniffing the first element.
13491360
return bool(

narwhals/dataframe.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@
3535
is_compliant_lazyframe,
3636
is_eager_allowed,
3737
is_index_selector,
38+
is_iterator,
3839
is_lazy_allowed,
3940
is_list_of,
4041
is_sequence_like,
4142
is_slice_none,
43+
predicates_contains_list_of_bool,
4244
qualified_type_name,
4345
supports_arrow_c_stream,
4446
zip_strict,
@@ -242,24 +244,20 @@ def drop(self, *columns: Iterable[str], strict: bool) -> Self:
242244
return self._with_compliant(self._compliant_frame.drop(columns, strict=strict))
243245

244246
def filter(
245-
self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
247+
self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any
246248
) -> Self:
247-
if len(predicates) == 1 and is_list_of(predicates[0], bool):
248-
predicate = predicates[0]
249-
else:
250-
from narwhals.functions import col
251-
252-
flat_predicates = flatten(predicates)
253-
check_expressions_preserve_length(*flat_predicates, function_name="filter")
254-
plx = self.__narwhals_namespace__()
255-
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
256-
compliant_constraints = (
257-
(col(name) == v)._to_compliant_expr(plx)
258-
for name, v in constraints.items()
259-
)
260-
predicate = plx.all_horizontal(
261-
*chain(compliant_predicates, compliant_constraints), ignore_nulls=False
262-
)
249+
from narwhals.functions import col
250+
251+
flat_predicates = flatten(predicates)
252+
check_expressions_preserve_length(*flat_predicates, function_name="filter")
253+
plx = self.__narwhals_namespace__()
254+
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
255+
compliant_constraints = (
256+
(col(name) == v)._to_compliant_expr(plx) for name, v in constraints.items()
257+
)
258+
predicate = plx.all_horizontal(
259+
*chain(compliant_predicates, compliant_constraints), ignore_nulls=False
260+
)
263261
return self._with_compliant(self._compliant_frame.filter(predicate))
264262

265263
def sort(
@@ -1653,7 +1651,7 @@ def filter(
16531651
16541652
Arguments:
16551653
*predicates: Expression(s) that evaluates to a boolean Series. Can
1656-
also be a (single!) boolean list.
1654+
also be a boolean list(s).
16571655
**constraints: Column filters; use `name = value` to filter columns by the supplied value.
16581656
Each constraint will behave the same as `nw.col(name).eq(value)`, and will be implicitly
16591657
joined with the other filter conditions using &.
@@ -1695,7 +1693,12 @@ def filter(
16951693
foo bar ham
16961694
1 2 7 b
16971695
"""
1698-
return super().filter(*predicates, **constraints)
1696+
impl = self.implementation
1697+
parsed_predicates = (
1698+
self._series.from_iterable("", p, backend=impl) if is_list_of(p, bool) else p
1699+
for p in predicates
1700+
)
1701+
return super().filter(*parsed_predicates, **constraints)
16991702

17001703
@overload
17011704
def group_by(
@@ -2850,15 +2853,14 @@ def unique(
28502853
)
28512854

28522855
def filter(
2853-
self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
2856+
self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any
28542857
) -> Self:
28552858
r"""Filter the rows in the LazyFrame based on a predicate expression.
28562859
28572860
The original order of the remaining rows is preserved.
28582861
28592862
Arguments:
2860-
*predicates: Expression that evaluates to a boolean Series. Can
2861-
also be a (single!) boolean list.
2863+
*predicates: Expression(s) that evaluates to a boolean Series.
28622864
**constraints: Column filters; use `name = value` to filter columns by the supplied value.
28632865
Each constraint will behave the same as `nw.col(name).eq(value)`, and will be implicitly
28642866
joined with the other filter conditions using &.
@@ -2924,13 +2926,12 @@ def filter(
29242926
└───────┴───────┴─────────┘
29252927
<BLANKLINE>
29262928
"""
2927-
if (
2928-
len(predicates) == 1 and is_list_of(predicates[0], bool) and not constraints
2929-
): # pragma: no cover
2929+
predicates_ = tuple(tuple(p) if is_iterator(p) else p for p in predicates)
2930+
if predicates_contains_list_of_bool(predicates_):
29302931
msg = "`LazyFrame.filter` is not supported with Python boolean masks - use expressions instead."
29312932
raise TypeError(msg)
29322933

2933-
return super().filter(*predicates, **constraints)
2934+
return super().filter(*predicates_, **constraints)
29342935

29352936
def sink_parquet(self, file: str | Path | BytesIO) -> None:
29362937
"""Write LazyFrame to Parquet file.

0 commit comments

Comments
 (0)