Skip to content

Commit 88501c6

Browse files
use initial instead + fix test for non-infer mode
1 parent 6a32c83 commit 88501c6

File tree

5 files changed

+25
-19
lines changed

5 files changed

+25
-19
lines changed

pandas/_libs/groupby.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def group_sum(
6767
result_mask: np.ndarray | None = ...,
6868
min_count: int = ...,
6969
is_datetimelike: bool = ...,
70-
is_string: bool = ...,
70+
initial: object = ...,
7171
skipna: bool = ...,
7272
) -> None: ...
7373
def group_prod(

pandas/_libs/groupby.pyx

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def group_sum(
707707
uint8_t[:, ::1] result_mask=None,
708708
Py_ssize_t min_count=0,
709709
bint is_datetimelike=False,
710-
bint is_string=False,
710+
object initial=0,
711711
bint skipna=True,
712712
) -> None:
713713
"""
@@ -726,13 +726,15 @@ def group_sum(
726726
raise ValueError("len(index) != len(labels)")
727727

728728
nobs = np.zeros((<object>out).shape, dtype=np.int64)
729-
# the below is equivalent to `np.zeros_like(out)` but faster
730-
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
731-
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
732-
733-
if is_string:
734-
# for strings start with empty string instead of 0 as `initial` value
735-
sumx = np.full((<object>out).shape, "", dtype=object)
729+
if initial == 0:
730+
# the below is equivalent to `np.zeros_like(out)` but faster
731+
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
732+
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
733+
else:
734+
# in practice this path is only taken for strings to use empty string as initial
735+
assert sum_t is object
736+
sumx = np.full((<object>out).shape, initial, dtype=object)
737+
# object code path does not use `compensation`
736738

737739
N, K = (<object>values).shape
738740
if uses_mask:

pandas/core/arrays/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2608,7 +2608,7 @@ def _groupby_op(
26082608
kind = WrappedCythonOp.get_kind_from_how(how)
26092609
op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)
26102610

2611-
is_string = False
2611+
initial = 0
26122612
# GH#43682
26132613
if isinstance(self.dtype, StringDtype):
26142614
# StringArray
@@ -2633,7 +2633,7 @@ def _groupby_op(
26332633

26342634
arr = self
26352635
if op.how == "sum":
2636-
is_string = True
2636+
initial = ""
26372637
# https://github.com/pandas-dev/pandas/issues/60229
26382638
# All NA should result in the empty string.
26392639
assert "skipna" in kwargs
@@ -2651,7 +2651,7 @@ def _groupby_op(
26512651
ngroups=ngroups,
26522652
comp_ids=ids,
26532653
mask=None,
2654-
is_string=is_string,
2654+
initial=initial,
26552655
**kwargs,
26562656
)
26572657

pandas/core/groupby/ops.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import functools
1313
from typing import (
1414
TYPE_CHECKING,
15+
Any,
1516
Generic,
1617
final,
1718
)
@@ -319,7 +320,7 @@ def _cython_op_ndim_compat(
319320
comp_ids: np.ndarray,
320321
mask: npt.NDArray[np.bool_] | None = None,
321322
result_mask: npt.NDArray[np.bool_] | None = None,
322-
is_string: bool = False,
323+
initial: Any = 0,
323324
**kwargs,
324325
) -> np.ndarray:
325326
if values.ndim == 1:
@@ -336,7 +337,7 @@ def _cython_op_ndim_compat(
336337
comp_ids=comp_ids,
337338
mask=mask,
338339
result_mask=result_mask,
339-
is_string=is_string,
340+
initial=initial,
340341
**kwargs,
341342
)
342343
if res.shape[0] == 1:
@@ -352,7 +353,7 @@ def _cython_op_ndim_compat(
352353
comp_ids=comp_ids,
353354
mask=mask,
354355
result_mask=result_mask,
355-
is_string=is_string,
356+
initial=initial,
356357
**kwargs,
357358
)
358359

@@ -366,7 +367,7 @@ def _call_cython_op(
366367
comp_ids: np.ndarray,
367368
mask: npt.NDArray[np.bool_] | None,
368369
result_mask: npt.NDArray[np.bool_] | None,
369-
is_string: bool = False,
370+
initial: Any = 0,
370371
**kwargs,
371372
) -> np.ndarray: # np.ndarray[ndim=2]
372373
orig_values = values
@@ -427,7 +428,7 @@ def _call_cython_op(
427428
if self.how == "sum":
428429
# pass in through kwargs only for sum (other functions don't have
429430
# the keyword)
430-
kwargs["is_string"] = is_string
431+
kwargs["initial"] = initial
431432
func(
432433
out=result,
433434
counts=counts,

pandas/tests/groupby/test_categorical.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def test_apply(ordered):
325325
tm.assert_series_equal(result, expected)
326326

327327

328-
def test_observed(observed):
328+
def test_observed(observed, using_infer_string):
329329
# multiple groupers, don't re-expand the output space
330330
# of the grouper
331331
# gh-14942 (implement)
@@ -360,7 +360,10 @@ def test_observed(observed):
360360
result = gb.sum()
361361
if not observed:
362362
expected = cartesian_product_for_groupers(
363-
expected, [cat1, cat2], list("AB"), fill_value={"values": 0, "C": ""}
363+
expected,
364+
[cat1, cat2],
365+
list("AB"),
366+
fill_value={"values": 0, "C": ""} if using_infer_string else 0,
364367
)
365368

366369
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)