Skip to content

Commit 73f8a84

Browse files
committed
[SPARK-53184][PS] melt when "value" has MultiIndex column labels
### What changes were proposed in this pull request? Fix `melt` when "value" has MultiIndex column labels ([SPARK-53184][PS]). ### Why are the changes needed? Ensure pandas on Spark works well under ANSI mode. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51914 from xinrong-meng/melt_multi. Authored-by: Xinrong Meng <xinrong@apache.org> Signed-off-by: Xinrong Meng <xinrong@apache.org>
1 parent 1bc8ce0 commit 73f8a84

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

python/pyspark/pandas/frame.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10638,15 +10638,20 @@ def melt(
1063810638
else:
1063910639
var_name = [var_name] # type: ignore[list-item]
1064010640

10641-
value_col_types = [
10642-
self._internal.spark_column_for(label).expr.dataType for label in value_vars
10643-
]
10644-
# If any value column is of StringType, cast all value columns to StringType to avoid
10645-
# ANSI mode errors during explode - mixing strings and integers.
10646-
string_cast_required_type = (
10647-
StringType() if any(isinstance(t, StringType) for t in value_col_types) else None
10648-
)
1064910641
use_cast = is_ansi_mode_enabled(self._internal.spark_frame.sparkSession)
10642+
string_cast_required_type = None
10643+
if use_cast:
10644+
field_by_label = {
10645+
label: field
10646+
for label, field in zip(self._internal.column_labels, self._internal.data_fields)
10647+
}
10648+
10649+
value_col_types = [field_by_label[label].spark_type for label in value_vars]
10650+
# If any value column is of StringType, cast all value columns to StringType to avoid
10651+
# ANSI mode errors during explode - mixing strings and integers.
10652+
string_cast_required_type = (
10653+
StringType() if any(isinstance(t, StringType) for t in value_col_types) else None
10654+
)
1065010655

1065110656
pairs = F.explode(
1065210657
F.array(
@@ -13824,16 +13829,12 @@ def _test() -> None:
1382413829
import uuid
1382513830
from pyspark.sql import SparkSession
1382613831
import pyspark.pandas.frame
13827-
from pyspark.testing.utils import is_ansi_mode_test
1382813832

1382913833
os.chdir(os.environ["SPARK_HOME"])
1383013834

1383113835
globs = pyspark.pandas.frame.__dict__.copy()
1383213836
globs["ps"] = pyspark.pandas
1383313837

13834-
if is_ansi_mode_test:
13835-
del pyspark.pandas.frame.DataFrame.melt.__doc__
13836-
1383713838
spark = (
1383813839
SparkSession.builder.master("local[4]").appName("pyspark.pandas.frame tests").getOrCreate()
1383913840
)

python/pyspark/pandas/namespace.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3879,7 +3879,6 @@ def _test() -> None:
38793879
from pyspark.sql import SparkSession
38803880
import pyspark.pandas.namespace
38813881
from pandas.util.version import Version
3882-
from pyspark.testing.utils import is_ansi_mode_test
38833882

38843883
os.chdir(os.environ["SPARK_HOME"])
38853884

@@ -3893,9 +3892,6 @@ def _test() -> None:
38933892
globs["ps"] = pyspark.pandas
38943893
globs["sf"] = F
38953894

3896-
if is_ansi_mode_test:
3897-
del pyspark.pandas.namespace.melt.__doc__
3898-
38993895
spark = (
39003896
SparkSession.builder.master("local[4]")
39013897
.appName("pyspark.pandas.namespace tests")

0 commit comments

Comments
 (0)