
Commit cb7b334

ENH: EA._cast_pointwise_result (#62105)
1 parent: 7cc093f · commit: cb7b334

24 files changed: +287 −321 lines


pandas/core/arrays/arrow/array.py

Lines changed: 67 additions & 0 deletions
@@ -392,6 +392,73 @@ def _from_sequence_of_strings(
         )
         return cls._from_sequence(scalars, dtype=pa_type, copy=copy)
 
+    def _cast_pointwise_result(self, values) -> ArrayLike:
+        if len(values) == 0:
+            # Retain our dtype
+            return self[:0].copy()
+
+        try:
+            arr = pa.array(values, from_pandas=True)
+        except (ValueError, TypeError):
+            # e.g. test_by_column_values_with_same_starting_value with nested
+            #  values, one entry of which is an ArrowStringArray
+            #  or test_agg_lambda_complex128_dtype_conversion for complex values
+            return super()._cast_pointwise_result(values)
+
+        if pa.types.is_duration(arr.type):
+            # workaround for https://github.com/apache/arrow/issues/40620
+            result = ArrowExtensionArray._from_sequence(values)
+            if pa.types.is_duration(self._pa_array.type):
+                result = result.astype(self.dtype)  # type: ignore[assignment]
+            elif pa.types.is_timestamp(self._pa_array.type):
+                # Try to retain original unit
+                new_dtype = ArrowDtype(pa.duration(self._pa_array.type.unit))
+                try:
+                    result = result.astype(new_dtype)  # type: ignore[assignment]
+                except ValueError:
+                    pass
+            elif pa.types.is_date64(self._pa_array.type):
+                # Try to match unit we get on non-pointwise op
+                dtype = ArrowDtype(pa.duration("ms"))
+                result = result.astype(dtype)  # type: ignore[assignment]
+            elif pa.types.is_date(self._pa_array.type):
+                # Try to match unit we get on non-pointwise op
+                dtype = ArrowDtype(pa.duration("s"))
+                result = result.astype(dtype)  # type: ignore[assignment]
+            return result
+
+        elif pa.types.is_date(arr.type) and pa.types.is_date(self._pa_array.type):
+            arr = arr.cast(self._pa_array.type)
+        elif pa.types.is_time(arr.type) and pa.types.is_time(self._pa_array.type):
+            arr = arr.cast(self._pa_array.type)
+        elif pa.types.is_decimal(arr.type) and pa.types.is_decimal(self._pa_array.type):
+            arr = arr.cast(self._pa_array.type)
+        elif pa.types.is_integer(arr.type) and pa.types.is_integer(self._pa_array.type):
+            try:
+                arr = arr.cast(self._pa_array.type)
+            except pa.lib.ArrowInvalid:
+                # e.g. test_combine_add if we can't cast
+                pass
+        elif pa.types.is_floating(arr.type) and pa.types.is_floating(
+            self._pa_array.type
+        ):
+            try:
+                arr = arr.cast(self._pa_array.type)
+            except pa.lib.ArrowInvalid:
+                # e.g. test_combine_add if we can't cast
+                pass
+
+        if isinstance(self.dtype, StringDtype):
+            if pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type):
+                # ArrowStringArrayNumpySemantics
+                return type(self)(arr).astype(self.dtype)
+            if self.dtype.na_value is np.nan:
+                # ArrowEA has different semantics, so we return numpy-based
+                #  result instead
+                return super()._cast_pointwise_result(values)
+            return ArrowExtensionArray(arr)
+        return type(self)(arr)
+
     @classmethod
     def _box_pa(
         cls, value, pa_type: pa.DataType | None = None
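
A rough sketch of the behaviour this hook targets for Arrow-backed arrays (not part of the commit; the expected result dtype is an assumption based on the casts above):

    import pandas as pd
    import pyarrow as pa

    arr = pd.array([1, 2, 3], dtype=pd.ArrowDtype(pa.int32()))

    # A pointwise op such as Series.map hands back Python scalars; the new hook
    # rebuilds an Arrow-backed array and, where both sides are integers, tries
    # to cast back to the original type (int32 rather than the inferred int64).
    result = arr._cast_pointwise_result([int(x) + 1 for x in arr])
    print(result.dtype)  # expected: int32[pyarrow]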

pandas/core/arrays/base.py

Lines changed: 9 additions & 40 deletions
@@ -19,7 +19,6 @@
     cast,
     overload,
 )
-import warnings
 
 import numpy as np
 
@@ -35,13 +34,11 @@
     Substitution,
     cache_readonly,
 )
-from pandas.util._exceptions import find_stack_level
 from pandas.util._validators import (
     validate_bool_kwarg,
     validate_insert_loc,
 )
 
-from pandas.core.dtypes.cast import maybe_cast_pointwise_result
 from pandas.core.dtypes.common import (
     is_list_like,
     is_scalar,
@@ -89,7 +86,6 @@
         AstypeArg,
         AxisInt,
         Dtype,
-        DtypeObj,
         FillnaOptions,
         InterpolateOptions,
         NumpySorter,
@@ -311,38 +307,6 @@ def _from_sequence(
         """
         raise AbstractMethodError(cls)
 
-    @classmethod
-    def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
-        """
-        Strict analogue to _from_sequence, allowing only sequences of scalars
-        that should be specifically inferred to the given dtype.
-
-        Parameters
-        ----------
-        scalars : sequence
-        dtype : ExtensionDtype
-
-        Raises
-        ------
-        TypeError or ValueError
-
-        Notes
-        -----
-        This is called in a try/except block when casting the result of a
-        pointwise operation.
-        """
-        try:
-            return cls._from_sequence(scalars, dtype=dtype, copy=False)
-        except (ValueError, TypeError):
-            raise
-        except Exception:
-            warnings.warn(
-                "_from_scalars should only raise ValueError or TypeError. "
-                "Consider overriding _from_scalars where appropriate.",
-                stacklevel=find_stack_level(),
-            )
-            raise
-
     @classmethod
     def _from_sequence_of_strings(
         cls, strings, *, dtype: ExtensionDtype, copy: bool = False
@@ -371,9 +335,6 @@ def _from_sequence_of_strings(
             from a sequence of scalars.
         api.extensions.ExtensionArray._from_factorized : Reconstruct an ExtensionArray
             after factorization.
-        api.extensions.ExtensionArray._from_scalars : Strict analogue to _from_sequence,
-            allowing only sequences of scalars that should be specifically inferred to
-            the given dtype.
 
         Examples
         --------
@@ -416,6 +377,14 @@ def _from_factorized(cls, values, original):
         """
         raise AbstractMethodError(cls)
 
+    def _cast_pointwise_result(self, values) -> ArrayLike:
+        """
+        Cast the result of a pointwise operation (e.g. Series.map) to an
+        array, preserve dtype_backend if possible.
+        """
+        values = np.asarray(values, dtype=object)
+        return lib.maybe_convert_objects(values, convert_non_numeric=True)
+
     # ------------------------------------------------------------------------
     # Must be a Sequence
     # ------------------------------------------------------------------------
@@ -2842,7 +2811,7 @@ def _maybe_convert(arr):
             # https://github.com/pandas-dev/pandas/issues/22850
             # We catch all regular exceptions here, and fall back
             # to an ndarray.
-            res = maybe_cast_pointwise_result(arr, self.dtype, same_dtype=False)
+            res = self._cast_pointwise_result(arr)
             if not isinstance(res, type(self)):
                 # exception raised in _from_sequence; ensure we have ndarray
                 res = np.asarray(arr)
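
With `_from_scalars` removed, `_cast_pointwise_result` becomes the override point for keeping pointwise results (e.g. `Series.map` or `combine`) in an extension dtype. A minimal sketch of how a downstream author might override it; `MyArray` and its `_from_sequence` behaviour are hypothetical, not part of this commit:

    from pandas.api.extensions import ExtensionArray

    class MyArray(ExtensionArray):
        # ... the usual ExtensionArray interface ...

        def _cast_pointwise_result(self, values):
            # Base behaviour: object ndarray -> best-effort inferred array.
            result = super()._cast_pointwise_result(values)
            try:
                # Re-wrap in our own dtype, giving up quietly if the mapped
                # values no longer fit it (e.g. a map that returned strings).
                return type(self)._from_sequence(result, dtype=self.dtype)
            except (ValueError, TypeError):
                return result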

pandas/core/arrays/categorical.py

Lines changed: 6 additions & 15 deletions
@@ -103,7 +103,6 @@
         AstypeArg,
         AxisInt,
         Dtype,
-        DtypeObj,
         NpDtype,
         Ordered,
         Shape,
@@ -529,20 +528,12 @@ def _from_sequence(
     ) -> Self:
         return cls(scalars, dtype=dtype, copy=copy)
 
-    @classmethod
-    def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
-        if dtype is None:
-            # The _from_scalars strictness doesn't make much sense in this case.
-            raise NotImplementedError
-
-        res = cls._from_sequence(scalars, dtype=dtype)
-
-        # if there are any non-category elements in scalars, these will be
-        #  converted to NAs in res.
-        mask = isna(scalars)
-        if not (mask == res.isna()).all():
-            # Some non-category element in scalars got converted to NA in res.
-            raise ValueError
+    def _cast_pointwise_result(self, values) -> ArrayLike:
+        res = super()._cast_pointwise_result(values)
+        cat = type(self)._from_sequence(res, dtype=self.dtype)
+        if (cat.isna() == isna(res)).all():
+            # i.e. the conversion was non-lossy
+            return cat
         return res
 
     @overload
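
The categorical override tries to round-trip the inferred result back into the original categories and only keeps the Categorical when nothing was lost. A small sketch of the intended behaviour (expectations are my reading of the method above, not taken from the commit's tests):

    import pandas as pd

    cat = pd.Categorical(["a", "b", "a"], categories=["a", "b"])

    # All values fit the existing categories: the round trip is non-lossy,
    # so a Categorical with the same dtype is expected back.
    cat._cast_pointwise_result(["b", "a", "b"])

    # "c" is not a category and would become NA in the Categorical, so the
    # method falls back to the plain inferred (object-dtype) result instead.
    cat._cast_pointwise_result(["b", "a", "c"])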

pandas/core/arrays/datetimes.py

Lines changed: 0 additions & 9 deletions
@@ -83,7 +83,6 @@
     from pandas._typing import (
         ArrayLike,
         DateTimeErrorChoices,
-        DtypeObj,
         IntervalClosedType,
         TimeAmbiguous,
         TimeNonexistent,
@@ -293,14 +292,6 @@ def _scalar_type(self) -> type[Timestamp]:
     _dtype: np.dtype[np.datetime64] | DatetimeTZDtype
     _freq: BaseOffset | None = None
 
-    @classmethod
-    def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
-        if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
-            # TODO: require any NAs be valid-for-DTA
-            # TODO: if dtype is passed, check for tzawareness compat?
-            raise ValueError
-        return cls._from_sequence(scalars, dtype=dtype)
-
    @classmethod
    def _validate_dtype(cls, values, dtype):
        # used in TimeLikeOps.__init__

pandas/core/arrays/masked.py

Lines changed: 14 additions & 0 deletions
@@ -26,6 +26,7 @@
 from pandas.util._decorators import doc
 
 from pandas.core.dtypes.base import ExtensionDtype
+from pandas.core.dtypes.cast import maybe_downcast_to_dtype
 from pandas.core.dtypes.common import (
     is_bool,
     is_integer_dtype,
@@ -147,6 +148,19 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self:
         values, mask = cls._coerce_to_array(scalars, dtype=dtype, copy=copy)
         return cls(values, mask)
 
+    def _cast_pointwise_result(self, values) -> ArrayLike:
+        values = np.asarray(values, dtype=object)
+        result = lib.maybe_convert_objects(values, convert_to_nullable_dtype=True)
+        lkind = self.dtype.kind
+        rkind = result.dtype.kind
+        if (lkind in "iu" and rkind in "iu") or (lkind == rkind == "f"):
+            result = cast(BaseMaskedArray, result)
+            new_data = maybe_downcast_to_dtype(
+                result._data, dtype=self.dtype.numpy_dtype
+            )
+            result = type(result)(new_data, result._mask)
+        return result
+
     @classmethod
     @doc(ExtensionArray._empty)
     def _empty(cls, shape: Shape, dtype: ExtensionDtype) -> Self:
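
For masked arrays the key step is the downcast back to the original width. A sketch, under the assumption that nullable inference yields Int64 for plain Python ints:

    import pandas as pd

    arr = pd.array([1, 2, 3], dtype="Int8")

    # maybe_convert_objects(..., convert_to_nullable_dtype=True) infers Int64
    # here; since both kinds are integer, the result is downcast back to the
    # original Int8.
    result = arr._cast_pointwise_result([x * 2 for x in arr.tolist()])
    print(result.dtype)  # expected: Int8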

pandas/core/arrays/numpy_.py

Lines changed: 23 additions & 1 deletion
@@ -14,7 +14,10 @@
 from pandas.compat.numpy import function as nv
 
 from pandas.core.dtypes.astype import astype_array
-from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
+from pandas.core.dtypes.cast import (
+    construct_1d_object_array_from_listlike,
+    maybe_downcast_to_dtype,
+)
 from pandas.core.dtypes.common import pandas_dtype
 from pandas.core.dtypes.dtypes import NumpyEADtype
 from pandas.core.dtypes.missing import isna
@@ -34,6 +37,7 @@
     from collections.abc import Callable
 
     from pandas._typing import (
+        ArrayLike,
         AxisInt,
         Dtype,
         FillnaOptions,
@@ -145,6 +149,24 @@ def _from_sequence(
             result = result.copy()
         return cls(result)
 
+    def _cast_pointwise_result(self, values) -> ArrayLike:
+        result = super()._cast_pointwise_result(values)
+        lkind = self.dtype.kind
+        rkind = result.dtype.kind
+        if (
+            (lkind in "iu" and rkind in "iu")
+            or (lkind == "f" and rkind == "f")
+            or (lkind == rkind == "c")
+        ):
+            result = maybe_downcast_to_dtype(result, self.dtype.numpy_dtype)
+        elif rkind == "M":
+            # Ensure potential subsequent .astype(object) doesn't incorrectly
+            #  convert Timestamps to ints
+            from pandas import array as pd_array
+
+            result = pd_array(result, copy=False)
+        return result
+
     # ------------------------------------------------------------------------
     # Data
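
The NumPy-backed variant does the analogous downcast on plain ndarrays. A sketch, assuming float inference gives float64 before the downcast:

    import numpy as np
    import pandas as pd

    arr = pd.arrays.NumpyExtensionArray(np.array([1.0, 2.0], dtype=np.float32))

    # The base implementation infers float64 from the object values; both
    # kinds are floating, so the result is downcast back to float32.
    result = arr._cast_pointwise_result([2.0, 4.0])
    print(result.dtype)  # expected: float32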

pandas/core/arrays/sparse/array.py

Lines changed: 17 additions & 0 deletions
@@ -607,6 +607,23 @@ def _from_sequence(
     def _from_factorized(cls, values, original) -> Self:
         return cls(values, dtype=original.dtype)
 
+    def _cast_pointwise_result(self, values):
+        result = super()._cast_pointwise_result(values)
+        if result.dtype.kind == self.dtype.kind:
+            try:
+                # e.g. test_groupby_agg_extension
+                res = type(self)._from_sequence(result, dtype=self.dtype)
+                if ((res == result) | (isna(result) & res.isna())).all():
+                    # This does not hold for e.g.
+                    #  test_arith_frame_with_scalar[0-__truediv__]
+                    return res
+                return type(self)._from_sequence(result)
+            except (ValueError, TypeError):
+                return type(self)._from_sequence(result)
+        else:
+            # e.g. test_combine_le avoid casting bools to Sparse[float64, nan]
+            return type(self)._from_sequence(result)
+
     # ------------------------------------------------------------------------
     # Data
     # ------------------------------------------------------------------------
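
A sketch of the two branches for SparseArray (expected dtypes are my reading of the method above):

    import pandas as pd

    sp = pd.arrays.SparseArray([0, 0, 1, 2], fill_value=0)

    # Same kind (integer) and a value-preserving round trip: expected to come
    # back as Sparse[int64, 0], matching the original dtype.
    sp._cast_pointwise_result([0, 0, 2, 4])

    # A boolean result takes the else-branch and keeps a boolean Sparse dtype
    # instead of being coerced into the original integer dtype.
    sp._cast_pointwise_result([False, False, True, True])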

pandas/core/arrays/string_.py

Lines changed: 7 additions & 7 deletions
@@ -412,13 +412,6 @@ def tolist(self) -> list:
             return [x.tolist() for x in self]
         return list(self.to_numpy())
 
-    @classmethod
-    def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
-        if lib.infer_dtype(scalars, skipna=True) not in ["string", "empty"]:
-            # TODO: require any NAs be valid-for-string
-            raise ValueError
-        return cls._from_sequence(scalars, dtype=dtype)
-
     def _formatter(self, boxed: bool = False):
         formatter = partial(
             printing.pprint_thing,
@@ -732,6 +725,13 @@ def _from_sequence_of_strings(
     ) -> Self:
         return cls._from_sequence(strings, dtype=dtype, copy=copy)
 
+    def _cast_pointwise_result(self, values) -> ArrayLike:
+        result = super()._cast_pointwise_result(values)
+        if isinstance(result.dtype, StringDtype):
+            # Ensure we retain our same na_value/storage
+            result = result.astype(self.dtype)  # type: ignore[call-overload]
+        return result
+
     @classmethod
     def _empty(cls, shape, dtype) -> StringArray:
         values = np.empty(shape, dtype=object)
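
A sketch for string arrays; whether the base inference returns a StringDtype result can depend on the active string-inference settings, so the final cast is what keeps the original storage and na_value:

    import pandas as pd

    arr = pd.array(["a", "b"], dtype=pd.StringDtype("python"))

    # If the inferred result is a StringDtype array, the override casts it
    # back to this array's exact string dtype (same storage and na_value).
    result = arr._cast_pointwise_result(["aa", "bb"])
    print(result.dtype)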
