|
35 | 35 | is_compliant_lazyframe, |
36 | 36 | is_eager_allowed, |
37 | 37 | is_index_selector, |
| 38 | + is_iterator, |
38 | 39 | is_lazy_allowed, |
39 | 40 | is_list_of, |
40 | 41 | is_sequence_like, |
41 | 42 | is_slice_none, |
| 43 | + predicates_contains_list_of_bool, |
42 | 44 | qualified_type_name, |
43 | 45 | supports_arrow_c_stream, |
44 | 46 | zip_strict, |
@@ -242,24 +244,20 @@ def drop(self, *columns: Iterable[str], strict: bool) -> Self: |
242 | 244 | return self._with_compliant(self._compliant_frame.drop(columns, strict=strict)) |
243 | 245 |
|
244 | 246 | def filter( |
245 | | - self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any |
| 247 | + self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any |
246 | 248 | ) -> Self: |
247 | | - if len(predicates) == 1 and is_list_of(predicates[0], bool): |
248 | | - predicate = predicates[0] |
249 | | - else: |
250 | | - from narwhals.functions import col |
251 | | - |
252 | | - flat_predicates = flatten(predicates) |
253 | | - check_expressions_preserve_length(*flat_predicates, function_name="filter") |
254 | | - plx = self.__narwhals_namespace__() |
255 | | - compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates) |
256 | | - compliant_constraints = ( |
257 | | - (col(name) == v)._to_compliant_expr(plx) |
258 | | - for name, v in constraints.items() |
259 | | - ) |
260 | | - predicate = plx.all_horizontal( |
261 | | - *chain(compliant_predicates, compliant_constraints), ignore_nulls=False |
262 | | - ) |
| 249 | + from narwhals.functions import col |
| 250 | + |
| 251 | + flat_predicates = flatten(predicates) |
| 252 | + check_expressions_preserve_length(*flat_predicates, function_name="filter") |
| 253 | + plx = self.__narwhals_namespace__() |
| 254 | + compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates) |
| 255 | + compliant_constraints = ( |
| 256 | + (col(name) == v)._to_compliant_expr(plx) for name, v in constraints.items() |
| 257 | + ) |
| 258 | + predicate = plx.all_horizontal( |
| 259 | + *chain(compliant_predicates, compliant_constraints), ignore_nulls=False |
| 260 | + ) |
263 | 261 | return self._with_compliant(self._compliant_frame.filter(predicate)) |
264 | 262 |
|
265 | 263 | def sort( |
@@ -1653,7 +1651,7 @@ def filter( |
1653 | 1651 |
|
1654 | 1652 | Arguments: |
1655 | 1653 | *predicates: Expression(s) that evaluates to a boolean Series. Can |
1656 | | - also be a (single!) boolean list. |
| 1654 | + also be a boolean list(s). |
1657 | 1655 | **constraints: Column filters; use `name = value` to filter columns by the supplied value. |
1658 | 1656 | Each constraint will behave the same as `nw.col(name).eq(value)`, and will be implicitly |
1659 | 1657 | joined with the other filter conditions using &. |
@@ -1695,7 +1693,12 @@ def filter( |
1695 | 1693 | foo bar ham |
1696 | 1694 | 1 2 7 b |
1697 | 1695 | """ |
1698 | | - return super().filter(*predicates, **constraints) |
| 1696 | + impl = self.implementation |
| 1697 | + parsed_predicates = ( |
| 1698 | + self._series.from_iterable("", p, backend=impl) if is_list_of(p, bool) else p |
| 1699 | + for p in predicates |
| 1700 | + ) |
| 1701 | + return super().filter(*parsed_predicates, **constraints) |
1699 | 1702 |
|
1700 | 1703 | @overload |
1701 | 1704 | def group_by( |
@@ -2850,15 +2853,14 @@ def unique( |
2850 | 2853 | ) |
2851 | 2854 |
|
2852 | 2855 | def filter( |
2853 | | - self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any |
| 2856 | + self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any |
2854 | 2857 | ) -> Self: |
2855 | 2858 | r"""Filter the rows in the LazyFrame based on a predicate expression. |
2856 | 2859 |
|
2857 | 2860 | The original order of the remaining rows is preserved. |
2858 | 2861 |
|
2859 | 2862 | Arguments: |
2860 | | - *predicates: Expression that evaluates to a boolean Series. Can |
2861 | | - also be a (single!) boolean list. |
| 2863 | + *predicates: Expression(s) that evaluates to a boolean Series. |
2862 | 2864 | **constraints: Column filters; use `name = value` to filter columns by the supplied value. |
2863 | 2865 | Each constraint will behave the same as `nw.col(name).eq(value)`, and will be implicitly |
2864 | 2866 | joined with the other filter conditions using &. |
@@ -2924,13 +2926,12 @@ def filter( |
2924 | 2926 | └───────┴───────┴─────────┘ |
2925 | 2927 | <BLANKLINE> |
2926 | 2928 | """ |
2927 | | - if ( |
2928 | | - len(predicates) == 1 and is_list_of(predicates[0], bool) and not constraints |
2929 | | - ): # pragma: no cover |
| 2929 | + predicates_ = tuple(tuple(p) if is_iterator(p) else p for p in predicates) |
| 2930 | + if predicates_contains_list_of_bool(predicates_): |
2930 | 2931 | msg = "`LazyFrame.filter` is not supported with Python boolean masks - use expressions instead." |
2931 | 2932 | raise TypeError(msg) |
2932 | 2933 |
|
2933 | | - return super().filter(*predicates, **constraints) |
| 2934 | + return super().filter(*predicates_, **constraints) |
2934 | 2935 |
|
2935 | 2936 | def sink_parquet(self, file: str | Path | BytesIO) -> None: |
2936 | 2937 | """Write LazyFrame to Parquet file. |
|
0 commit comments