Skip to content

Commit 4b650d0

Browse files
committed
Updated with reviewer suggestions
1 parent c05b1a7 commit 4b650d0

File tree

3 files changed

+65
-52
lines changed

3 files changed

+65
-52
lines changed

pandas/core/apply.py

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,8 @@ def apply(
216216

217217
NumbaExecutionEngine.check_numba_support(func)
218218

219-
# normalize axis values
220-
if axis in (0, "index"):
221-
axis = 0
222-
else:
223-
axis = 1
219+
if not isinstance(data, np.ndarray):
220+
axis = data._get_axis_number(cast(Axis, axis))
224221

225222
# check for data typing
226223
if not isinstance(data, np.ndarray):
@@ -230,7 +227,7 @@ def apply(
230227
engine_kwargs = extract_numba_options(decorator)
231228

232229
NumbaExecutionEngine.validate_values_for_numba_raw_false(
233-
data, get_jit_arguments(engine_kwargs)
230+
data, **get_jit_arguments(engine_kwargs)
234231
)
235232

236233
return NumbaExecutionEngine.apply_raw_false(
@@ -288,9 +285,29 @@ def apply_raw_false(
288285
Series,
289286
)
290287

291-
results = NumbaExecutionEngine.apply_with_numba(
292-
data, func, args, kwargs, decorator, axis
288+
func = cast(Callable, func)
289+
args, kwargs = prepare_function_arguments(
290+
func, args, kwargs, num_required_args=1
293291
)
292+
nb_func = NumbaExecutionEngine.generate_numba_apply_func(func, axis, decorator)
293+
294+
from pandas.core._numba.extensions import set_numba_data
295+
296+
# Convert from numba dict to regular dict
297+
# Our isinstance checks in the df constructor don't pass for numbas typed dict
298+
299+
if axis == 0:
300+
col_names_index = data.index
301+
result_index = data.columns
302+
else:
303+
col_names_index = data.columns
304+
result_index = data.index
305+
306+
with (
307+
set_numba_data(result_index) as index,
308+
set_numba_data(col_names_index) as columns,
309+
):
310+
results = dict(nb_func(data.values, columns, index, *args))
294311

295312
if results:
296313
sample = next(iter(results.values()))
@@ -306,11 +323,14 @@ def apply_raw_false(
306323

307324
@staticmethod
308325
def validate_values_for_numba_raw_false(
309-
data: Series | DataFrame, engine_kwargs: dict[str, bool]
326+
data: Series | DataFrame,
327+
nopython: bool | None = None,
328+
nogil: bool | None = None,
329+
parallel: bool | None = None,
310330
) -> None:
311331
from pandas import Series
312332

313-
if engine_kwargs.get("parallel", False):
333+
if parallel:
314334
raise NotImplementedError(
315335
"Parallel apply is not supported when raw=False and engine='numba'"
316336
)
@@ -376,34 +396,6 @@ def numba_func(values, col_names_index, index, *args):
376396

377397
return numba_func
378398

379-
@staticmethod
380-
def apply_with_numba(data, func, args, kwargs, decorator, axis=0) -> dict[int, Any]:
381-
func = cast(Callable, func)
382-
args, kwargs = prepare_function_arguments(
383-
func, args, kwargs, num_required_args=1
384-
)
385-
nb_func = NumbaExecutionEngine.generate_numba_apply_func(func, axis, decorator)
386-
387-
from pandas.core._numba.extensions import set_numba_data
388-
389-
# Convert from numba dict to regular dict
390-
# Our isinstance checks in the df constructor don't pass for numbas typed dict
391-
392-
if axis == 0 or axis == "index":
393-
col_names_index = data.index
394-
result_index = data.columns
395-
else:
396-
col_names_index = data.columns
397-
result_index = data.index
398-
399-
with (
400-
set_numba_data(result_index) as index,
401-
set_numba_data(col_names_index) as columns,
402-
):
403-
res = dict(nb_func(data.values, columns, index, *args))
404-
405-
return res
406-
407399

408400
def frame_apply(
409401
obj: DataFrame,

pandas/core/util/numba_.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -152,20 +152,37 @@ def prepare_function_arguments(
152152

153153
def extract_numba_options(decorator: Callable) -> dict:
154154
"""
155-
Extract targetoptions from a numba.jit decorator
155+
Extract the `targetoptions` dictionary from a numba.jit decorator.
156+
157+
The `targetoptions` attribute stores the keyword arguments
158+
passed to the `@numba.jit` decorator when it is created.
159+
160+
This function returns a dictionary with the following keys,
161+
if present in the decorator:
162+
- nopython
163+
- nogil
164+
- parallel
165+
166+
Parameters
167+
----------
168+
decorator : Callable
169+
A numba.jit decorated function or a numba dispatcher object.
170+
171+
Returns
172+
-------
173+
dict
174+
A dictionary with the extracted numba compilation options.
156175
"""
157-
try:
158-
closure = decorator.__closure__
159-
if closure is None:
160-
return {}
161-
freevars = decorator.__code__.co_freevars
162-
if "targetoptions" not in freevars:
163-
return {}
164-
idx = freevars.index("targetoptions")
165-
cell = closure[idx]
166-
targetoptions = cell.cell_contents
167-
if isinstance(targetoptions, dict):
168-
return targetoptions
176+
closure = decorator.__closure__
177+
if closure is None:
169178
return {}
170-
except Exception:
179+
freevars = decorator.__code__.co_freevars
180+
if "targetoptions" not in freevars:
171181
return {}
182+
idx = freevars.index("targetoptions")
183+
cell = closure[idx]
184+
targetoptions = cell.cell_contents
185+
if isinstance(targetoptions, dict):
186+
relevant_keys = {"nopython", "nogil", "parallel"}
187+
return {k: v for k, v in targetoptions.items() if k in relevant_keys}
188+
return {}

pandas/tests/apply/test_numba.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,7 @@ def test_extract_numba_options_from_user_decorated_function(jit_args):
141141
extracted = extract_numba_options(numba.jit(**jit_args))
142142
for k, v in jit_args.items():
143143
assert extracted.get(k) == v
144+
145+
extracted = extract_numba_options(numba.njit(**jit_args))
146+
for k, v in jit_args.items():
147+
assert extracted.get(k) == v

0 commit comments

Comments
 (0)