Skip to content

Commit ddbc8ec

Browse files
BUG: fix fill value for grouped sum in case of unobserved categories for string dtype (empty string instead of 0)
1 parent 6a6a1ba commit ddbc8ec

File tree

5 files changed

+26
-5
lines changed

5 files changed

+26
-5
lines changed

pandas/_libs/groupby.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def group_sum(
6767
result_mask: np.ndarray | None = ...,
6868
min_count: int = ...,
6969
is_datetimelike: bool = ...,
70+
is_string: bool = ...,
7071
skipna: bool = ...,
7172
) -> None: ...
7273
def group_prod(

pandas/_libs/groupby.pyx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,7 @@ def group_sum(
707707
uint8_t[:, ::1] result_mask=None,
708708
Py_ssize_t min_count=0,
709709
bint is_datetimelike=False,
710+
bint is_string=False,
710711
bint skipna=True,
711712
) -> None:
712713
"""
@@ -729,6 +730,10 @@ def group_sum(
729730
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
730731
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
731732

733+
if is_string:
734+
# for strings start with empty string instead of 0 as `initial` value
735+
sumx = np.full((<object>out).shape, "", dtype=object)
736+
732737
N, K = (<object>values).shape
733738
if uses_mask:
734739
nan_val = 0

pandas/core/arrays/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2608,6 +2608,7 @@ def _groupby_op(
26082608
kind = WrappedCythonOp.get_kind_from_how(how)
26092609
op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)
26102610

2611+
is_string = False
26112612
# GH#43682
26122613
if isinstance(self.dtype, StringDtype):
26132614
# StringArray
@@ -2632,6 +2633,7 @@ def _groupby_op(
26322633

26332634
arr = self
26342635
if op.how == "sum":
2636+
is_string = True
26352637
# https://github.com/pandas-dev/pandas/issues/60229
26362638
# All NA should result in the empty string.
26372639
assert "skipna" in kwargs
@@ -2649,6 +2651,7 @@ def _groupby_op(
26492651
ngroups=ngroups,
26502652
comp_ids=ids,
26512653
mask=None,
2654+
is_string=is_string,
26522655
**kwargs,
26532656
)
26542657

pandas/core/groupby/ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ def _cython_op_ndim_compat(
319319
comp_ids: np.ndarray,
320320
mask: npt.NDArray[np.bool_] | None = None,
321321
result_mask: npt.NDArray[np.bool_] | None = None,
322+
is_string: bool = False,
322323
**kwargs,
323324
) -> np.ndarray:
324325
if values.ndim == 1:
@@ -335,6 +336,7 @@ def _cython_op_ndim_compat(
335336
comp_ids=comp_ids,
336337
mask=mask,
337338
result_mask=result_mask,
339+
is_string=is_string,
338340
**kwargs,
339341
)
340342
if res.shape[0] == 1:
@@ -350,6 +352,7 @@ def _cython_op_ndim_compat(
350352
comp_ids=comp_ids,
351353
mask=mask,
352354
result_mask=result_mask,
355+
is_string=is_string,
353356
**kwargs,
354357
)
355358

@@ -363,6 +366,7 @@ def _call_cython_op(
363366
comp_ids: np.ndarray,
364367
mask: npt.NDArray[np.bool_] | None,
365368
result_mask: npt.NDArray[np.bool_] | None,
369+
is_string: bool = False,
366370
**kwargs,
367371
) -> np.ndarray: # np.ndarray[ndim=2]
368372
orig_values = values
@@ -420,6 +424,10 @@ def _call_cython_op(
420424
"sum",
421425
"median",
422426
]:
427+
if self.how == "sum":
428+
# pass in through kwargs only for sum (other functions don't have
429+
# the keyword)
430+
kwargs["is_string"] = is_string
423431
func(
424432
out=result,
425433
counts=counts,

pandas/tests/groupby/test_categorical.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ def f(a):
3232
return a
3333

3434
index = MultiIndex.from_product(map(f, args), names=names)
35+
if isinstance(fill_value, dict):
36+
# fill_value is a dict mapping column names to fill values
37+
# -> reindex column by column (reindex itself does not support this)
38+
res = {}
39+
for col in result.columns:
40+
res[col] = result[col].reindex(index, fill_value=fill_value[col])
41+
return DataFrame(res, index=index).sort_index()
42+
3543
return result.reindex(index, fill_value=fill_value).sort_index()
3644

3745

@@ -325,10 +333,6 @@ def test_observed(request, using_infer_string, observed):
325333
# gh-8138 (back-compat)
326334
# gh-8869
327335

328-
if using_infer_string and not observed:
329-
# TODO(infer_string) this fails with filling the string column with 0
330-
request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))
331-
332336
cat1 = Categorical(["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True)
333337
cat2 = Categorical(["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True)
334338
df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]})
@@ -356,7 +360,7 @@ def test_observed(request, using_infer_string, observed):
356360
result = gb.sum()
357361
if not observed:
358362
expected = cartesian_product_for_groupers(
359-
expected, [cat1, cat2], list("AB"), fill_value=0
363+
expected, [cat1, cat2], list("AB"), fill_value={"values": 0, "C": ""}
360364
)
361365

362366
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)