Skip to content

Commit 551a8e8

Browse files
authored
REF: remove unnecessary case from maybe_downcast_to_dtype (#62166)
1 parent 728be93 commit 551a8e8

File tree

2 files changed

+6
-75
lines changed

2 files changed

+6
-75
lines changed

pandas/core/dtypes/cast.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -246,18 +246,16 @@ def _disallow_mismatched_datetimelike(value, dtype: DtypeObj) -> None:
246246

247247

248248
@overload
249-
def maybe_downcast_to_dtype(
250-
result: np.ndarray, dtype: str | np.dtype
251-
) -> np.ndarray: ...
249+
def maybe_downcast_to_dtype(result: np.ndarray, dtype: np.dtype) -> np.ndarray: ...
252250

253251

254252
@overload
255253
def maybe_downcast_to_dtype(
256-
result: ExtensionArray, dtype: str | np.dtype
257-
) -> ArrayLike: ...
254+
result: ExtensionArray, dtype: np.dtype
255+
) -> ExtensionArray: ...
258256

259257

260-
def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLike:
258+
def maybe_downcast_to_dtype(result: ArrayLike, dtype: np.dtype) -> ArrayLike:
261259
"""
262260
try to cast to the specified dtype (e.g. convert back to bool/int
263261
or could be an astype of float64->float32
@@ -266,30 +264,6 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
266264
result = result._values
267265
do_round = False
268266

269-
if isinstance(dtype, str):
270-
if dtype == "infer":
271-
inferred_type = lib.infer_dtype(result, skipna=False)
272-
if inferred_type == "boolean":
273-
dtype = "bool"
274-
elif inferred_type == "integer":
275-
dtype = "int64"
276-
elif inferred_type == "datetime64":
277-
dtype = "datetime64[ns]"
278-
elif inferred_type in ["timedelta", "timedelta64"]:
279-
dtype = "timedelta64[ns]"
280-
281-
# try to upcast here
282-
elif inferred_type == "floating":
283-
dtype = "int64"
284-
if issubclass(result.dtype.type, np.number):
285-
do_round = True
286-
287-
else:
288-
# TODO: complex? what if result is already non-object?
289-
dtype = "object"
290-
291-
dtype = np.dtype(dtype)
292-
293267
if not isinstance(dtype, np.dtype):
294268
# enforce our signature annotation
295269
raise TypeError(dtype) # pragma: no cover

pandas/tests/dtypes/cast/test_downcast.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,43 +7,20 @@
77

88
from pandas import (
99
Series,
10-
Timedelta,
1110
)
1211
import pandas._testing as tm
1312

1413

1514
@pytest.mark.parametrize(
1615
"arr,dtype,expected",
1716
[
18-
(
19-
np.array([8.5, 8.6, 8.7, 8.8, 8.9999999999995]),
20-
"infer",
21-
np.array([8.5, 8.6, 8.7, 8.8, 8.9999999999995]),
22-
),
23-
(
24-
np.array([8.0, 8.0, 8.0, 8.0, 8.9999999999995]),
25-
"infer",
26-
np.array([8, 8, 8, 8, 9], dtype=np.int64),
27-
),
28-
(
29-
np.array([8.0, 8.0, 8.0, 8.0, 9.0000000000005]),
30-
"infer",
31-
np.array([8, 8, 8, 8, 9], dtype=np.int64),
32-
),
3317
(
3418
# This is a judgement call, but we do _not_ downcast Decimal
3519
# objects
3620
np.array([decimal.Decimal("0.0")]),
37-
"int64",
21+
np.dtype("int64"),
3822
np.array([decimal.Decimal("0.0")]),
3923
),
40-
(
41-
# GH#45837
42-
np.array([Timedelta(days=1), Timedelta(days=2)], dtype=object),
43-
"infer",
44-
np.array([1, 2], dtype="m8[D]").astype("m8[ns]"),
45-
),
46-
# TODO: similar for dt64, dt64tz, Period, Interval?
4724
],
4825
)
4926
def test_downcast(arr, expected, dtype):
@@ -60,26 +37,6 @@ def test_downcast_booleans():
6037
tm.assert_numpy_array_equal(result, expected)
6138

6239

63-
def test_downcast_conversion_no_nan(any_real_numpy_dtype):
64-
dtype = any_real_numpy_dtype
65-
expected = np.array([1, 2])
66-
arr = np.array([1.0, 2.0], dtype=dtype)
67-
68-
result = maybe_downcast_to_dtype(arr, "infer")
69-
tm.assert_almost_equal(result, expected, check_dtype=False)
70-
71-
72-
def test_downcast_conversion_nan(float_numpy_dtype):
73-
dtype = float_numpy_dtype
74-
data = [1.0, 2.0, np.nan]
75-
76-
expected = np.array(data, dtype=dtype)
77-
arr = np.array(data, dtype=dtype)
78-
79-
result = maybe_downcast_to_dtype(arr, "infer")
80-
tm.assert_almost_equal(result, expected)
81-
82-
8340
def test_downcast_conversion_empty(any_real_numpy_dtype):
8441
dtype = any_real_numpy_dtype
8542
arr = np.array([], dtype=dtype)
@@ -89,7 +46,7 @@ def test_downcast_conversion_empty(any_real_numpy_dtype):
8946

9047
@pytest.mark.parametrize("klass", [np.datetime64, np.timedelta64])
9148
def test_datetime_likes_nan(klass):
92-
dtype = klass.__name__ + "[ns]"
49+
dtype = np.dtype(klass.__name__ + "[ns]")
9350
arr = np.array([1, 2, np.nan])
9451

9552
exp = np.array([1, 2, klass("NaT")], dtype)

0 commit comments

Comments
 (0)