diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py
index 016b30e215d02..406bf04322177 100644
--- a/pandas/core/indexes/multi.py
+++ b/pandas/core/indexes/multi.py
@@ -4135,6 +4135,68 @@ def isin(self, values, level=None) -> npt.NDArray[np.bool_]:
     # base class "Index" defined the type as "Callable[[Index, Any, bool], Any]")
     rename = Index.set_names  # type: ignore[assignment]
 
+    def difference(self, other, sort=None):
+        """
+        Return a new MultiIndex with elements of self that are not in other.
+
+        Levels backed by a pyarrow timestamp dtype are swapped for their
+        native pandas datetime equivalent before the set operation runs, and
+        the pyarrow dtypes of ``self`` are put back on the result.
+
+        Parameters
+        ----------
+        other : MultiIndex or array-like of tuples
+            Entries to remove from ``self``.
+        sort : bool or None, default None
+            Forwarded unchanged to ``Index.difference``.
+
+        Returns
+        -------
+        MultiIndex
+        """
+        # Imported locally so the method does not depend on the module-level
+        # import list.
+        from pandas.core.dtypes.dtypes import (
+            ArrowDtype,
+            DatetimeTZDtype,
+        )
+
+        def with_native_timestamps(index: MultiIndex):
+            # Only the level values are converted; the codes stay valid
+            # because ``astype`` keeps each level's length and order.
+            converted = {}
+            for pos, level in enumerate(index.levels):
+                dtype = level.dtype
+                if not (isinstance(dtype, ArrowDtype) and dtype.kind == "M"):
+                    continue
+                pa_type = dtype.pyarrow_dtype
+                if pa_type.tz is None:
+                    native_dtype = dtype.numpy_dtype
+                else:
+                    native_dtype = DatetimeTZDtype(unit=pa_type.unit, tz=pa_type.tz)
+                index = index.set_levels(level.astype(native_dtype), level=pos)
+                converted[pos] = dtype
+            return index, converted
+
+        this, arrow_dtypes = with_native_timestamps(self)
+        that, other_dtypes = other, {}
+        if isinstance(other, MultiIndex):
+            that, other_dtypes = with_native_timestamps(other)
+        if not arrow_dtypes and not other_dtypes:
+            # No pyarrow timestamp level on either side.
+            return super().difference(other, sort=sort)
+
+        # Call the base implementation directly: ``this.difference`` would
+        # re-enter this override.
+        result = Index.difference(this, that, sort=sort)
+
+        # Put back the pyarrow dtypes that ``self`` carried.
+        if isinstance(result, MultiIndex) and result.nlevels == self.nlevels:
+            for pos, dtype in arrow_dtypes.items():
+                restored = result.levels[pos].astype(dtype)
+                result = result.set_levels(restored, level=pos)
+        return result
+
     # ---------------------------------------------------------------
     # Arithmetic/Numeric Methods - Disabled
 
diff --git a/pandas/tests/indexes/multi/test_timestamp.py b/pandas/tests/indexes/multi/test_timestamp.py
new file mode 100644
index 0000000000000..15a846b8d1b78
--- /dev/null
+++ b/pandas/tests/indexes/multi/test_timestamp.py
@@ -0,0 +1,20 @@
+import pytest
+
+import pandas as pd
+import pandas._testing as tm
+
+pytest.importorskip("pyarrow")
+
+
+def test_difference_with_pyarrow_timestamp():
+    # The pyarrow timestamp level must keep its dtype through the set operation.
+    dates = pd.Series(["2024-01-01", "2024-01-02"], dtype="timestamp[ns][pyarrow]")
+    mi = pd.MultiIndex.from_arrays([[1, 2], dates], names=["id", "date"])
+
+    result = mi.difference(mi[:1])
+
+    expected = pd.MultiIndex.from_arrays(
+        [[2], pd.Series(["2024-01-02"], dtype="timestamp[ns][pyarrow]")],
+        names=["id", "date"],
+    )
+    tm.assert_index_equal(result, expected)