Changes from 16 of 27 commits
ac5d83c
first commit
xianzhe-databricks Sep 26, 2025
1e702c9
data source related
xianzhe-databricks Sep 26, 2025
659fe4e
add conf on connect
xianzhe-databricks Sep 26, 2025
e791532
add partial tests
xianzhe-databricks Sep 26, 2025
6787bfc
Update python/pyspark/sql/tests/arrow/test_arrow_binary_as_bytes_udf.py
HyukjinKwon Sep 28, 2025
453298d
add tests for spark connect
xianzhe-databricks Sep 26, 2025
c859ce7
resolve with remote
xianzhe-databricks Sep 29, 2025
09bfef3
lint
xianzhe-databricks Sep 26, 2025
238e2b7
fix ci
xianzhe-databricks Sep 29, 2025
8db164c
add docs
xianzhe-databricks Sep 29, 2025
b10fc9d
add tests for arrow python udtf
xianzhe-databricks Sep 29, 2025
7d3cf56
add a shield
xianzhe-databricks Sep 29, 2025
922c1a4
rename conf; also make classic dataframe API work with the conf
xianzhe-databricks Sep 30, 2025
16e8428
add size
xianzhe-databricks Sep 30, 2025
fd5fcb2
apply binary as bytes consistently in all cases
xianzhe-databricks Sep 30, 2025
fd687f2
reformat
xianzhe-databricks Sep 30, 2025
ee5d2c2
remove usearrow
xianzhe-databricks Sep 30, 2025
31f52c8
udf test restructure
xianzhe-databricks Oct 1, 2025
d63d7b3
doc
xianzhe-databricks Oct 1, 2025
91aed20
fix ci
xianzhe-databricks Oct 1, 2025
9ce5dfd
fix foreach partition
xianzhe-databricks Oct 1, 2025
e488e3b
add tests for nested structure
xianzhe-databricks Oct 1, 2025
4683449
address comments
xianzhe-databricks Oct 2, 2025
852f9d9
fix build and simplify tests
xianzhe-databricks Oct 2, 2025
b933d61
address comments
xianzhe-databricks Oct 7, 2025
fa202b1
move utils
xianzhe-databricks Oct 7, 2025
9d0c72e
hope to fix mypy
xianzhe-databricks Oct 8, 2025
2 changes: 1 addition & 1 deletion docs/sql-ref-datatypes.md
@@ -131,7 +131,7 @@ from pyspark.sql.types import *
|**StringType**|str|StringType()|
|**CharType(length)**|str|CharType(length)|
|**VarcharType(length)**|str|VarcharType(length)|
|**BinaryType**|bytearray|BinaryType()|
|**BinaryType**|bytes|BinaryType()|
|**BooleanType**|bool|BooleanType()|
|**TimestampType**|datetime.datetime|TimestampType()|
|**TimestampNTZType**|datetime.datetime|TimestampNTZType()|
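A minimal sketch of the mapping documented in the table above, assuming an existing SparkSession `spark` on a build that contains this change, with the conf left at the default of true that the diff reads:

```python
from pyspark.sql.types import BinaryType, StructField, StructType

# Minimal check of the updated BinaryType -> bytes mapping (assumes `spark`
# is an existing SparkSession and binaryAsBytes is at its default of true).
schema = StructType([StructField("b", BinaryType())])
df = spark.createDataFrame([(b"\x00\x01",)], schema)
value = df.first()["b"]
assert isinstance(value, bytes) and not isinstance(value, bytearray)
```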
4 changes: 3 additions & 1 deletion python/docs/source/tutorial/sql/type_conversions.rst
@@ -67,6 +67,8 @@ are listed below:
* - spark.sql.execution.pandas.inferPandasDictAsMap
- When enabled, Pandas dictionaries are inferred as MapType. Otherwise, they are inferred as StructType.
- False
* - spark.sql.execution.pyspark.binaryAsBytes
- Introduced in Spark 4.1.0. When enabled, BinaryType is mapped consistently to Python bytes; when disabled, matches the PySpark default behavior before 4.1.0.
Contributor: maybe we can explain briefly what the default behavior was before 4.1.0?

Author: I ended up explaining more in the definition of spark.sql.execution.pyspark.binaryAsBytes in the SQL conf, as it is really long...

All Conversions
---------------
@@ -105,7 +107,7 @@ All Conversions
- string
- StringType()
* - **BinaryType**
- bytearray
- bytes
- BinaryType()
* - **BooleanType**
- bool
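To make the reviewer's question above concrete, here is a hedged sketch of the behavior the conf toggles, assuming an existing SparkSession `spark`; the bytearray result reflects my reading of the pre-4.1.0 default:

```python
df = spark.createDataFrame([(b"hello",)], "b binary")

# New default in 4.1.0: binary values come back as immutable bytes.
spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "true")
print(type(df.collect()[0].b).__name__)  # bytes

# Disabling the conf restores the pre-4.1.0 behavior of returning bytearray.
spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "false")
print(type(df.collect()[0].b).__name__)  # bytearray
```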
21 changes: 18 additions & 3 deletions python/pyspark/sql/connect/dataframe.py
@@ -1823,7 +1823,16 @@ def collect(self) -> List[Row]:

assert schema is not None and isinstance(schema, StructType)

return ArrowTableToRowsConversion.convert(table, schema)
return ArrowTableToRowsConversion.convert(
table, schema, binary_as_bytes=self._get_binary_as_bytes()
)

def _get_binary_as_bytes(self) -> bool:
"""Get the binary_as_bytes configuration value from Spark session."""
return (
self._session.conf.get("spark.sql.execution.pyspark.binaryAsBytes", "true").lower()
== "true"
)

def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
query = self._plan.to_proto(self._session.client)
@@ -2075,7 +2084,9 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
table = schema_or_table
if schema is None:
schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
yield from ArrowTableToRowsConversion.convert(table, schema)
yield from ArrowTableToRowsConversion.convert(
table, schema, binary_as_bytes=self._get_binary_as_bytes()
)

def pandas_api(
self, index_col: Optional[Union[str, List[str]]] = None
@@ -2161,8 +2172,12 @@ def foreach_func(row: Any) -> None:

def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
schema = self._schema
binary_as_bytes = self._get_binary_as_bytes()
field_converters = [
ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
ArrowTableToRowsConversion._create_converter(
f.dataType, binary_as_bytes=binary_as_bytes
)
for f in schema.fields
]

def foreach_partition_func(itr: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
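Because the Connect DataFrame re-reads the conf on each call through `_get_binary_as_bytes()`, `collect()`, `toLocalIterator()`, and `foreachPartition()` should all pick up the current setting. A small sketch, assuming `spark` is a Spark Connect session:

```python
# Sketch: toLocalIterator() also goes through ArrowTableToRowsConversion,
# so binary columns should arrive as bytes when the conf is enabled
# (assumes `spark` is a Spark Connect SparkSession).
spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "true")
df = spark.createDataFrame([(b"abc",), (b"def",)], "b binary")
for row in df.toLocalIterator():
    assert isinstance(row.b, bytes)
```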
28 changes: 18 additions & 10 deletions python/pyspark/sql/conversion.py
@@ -518,13 +518,13 @@ def _create_converter(dataType: DataType) -> Callable:
@overload
@staticmethod
def _create_converter(
dataType: DataType, *, none_on_identity: bool = True
dataType: DataType, *, none_on_identity: bool = True, binary_as_bytes: bool = True
) -> Optional[Callable]:
pass

@staticmethod
def _create_converter(
dataType: DataType, *, none_on_identity: bool = False
dataType: DataType, *, none_on_identity: bool = False, binary_as_bytes: bool = True
) -> Optional[Callable]:
assert dataType is not None and isinstance(dataType, DataType)

@@ -542,7 +542,9 @@ def _create_converter(
dedup_field_names = _dedup_names(field_names)

field_convs = [
ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
for f in dataType.fields
]

@@ -564,7 +566,7 @@ def convert_struct(value: Any) -> Any:

elif isinstance(dataType, ArrayType):
element_conv = ArrowTableToRowsConversion._create_converter(
dataType.elementType, none_on_identity=True
dataType.elementType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if element_conv is None:
@@ -589,10 +591,10 @@ def convert_array(value: Any) -> Any:

elif isinstance(dataType, MapType):
key_conv = ArrowTableToRowsConversion._create_converter(
dataType.keyType, none_on_identity=True
dataType.keyType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
value_conv = ArrowTableToRowsConversion._create_converter(
dataType.valueType, none_on_identity=True
dataType.valueType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if key_conv is None:
@@ -646,7 +648,7 @@ def convert_binary(value: Any) -> Any:
return None
else:
assert isinstance(value, bytes)
return bytearray(value)
return value if binary_as_bytes else bytearray(value)

return convert_binary

@@ -676,7 +678,7 @@ def convert_timestample_ntz(value: Any) -> Any:
udt: UserDefinedType = dataType

conv = ArrowTableToRowsConversion._create_converter(
udt.sqlType(), none_on_identity=True
udt.sqlType(), none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if conv is None:
@@ -735,7 +737,11 @@ def convert(

@staticmethod # type: ignore[misc]
def convert(
table: "pa.Table", schema: StructType, *, return_as_tuples: bool = False
table: "pa.Table",
schema: StructType,
*,
return_as_tuples: bool = False,
binary_as_bytes: bool = True,
Contributor: maybe this should be controlled by a flag?

Author: it is already controlled by a flag: binary_as_bytes is passed at the caller's site with the value of the SQL conf spark.sql.execution.pyspark.binaryAsBytes. It is not possible, or against the style, to access the SQL conf in this conversion.py.
) -> List[Union[Row, tuple]]:
require_minimum_pyarrow_version()
import pyarrow as pa
@@ -748,7 +754,9 @@ def convert(

if len(fields) > 0:
field_converters = [
ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
for f in schema.fields
]

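The behavioral core of the change is the final line of `convert_binary`. A self-contained mirror of that logic (not the actual Spark code path) shows the two outcomes:

```python
# Standalone mirror of convert_binary's flag handling: Arrow hands the
# converter a Python bytes value, and the flag decides whether to copy it
# into a bytearray for backward compatibility.
def convert_binary(value, binary_as_bytes=True):
    if value is None:
        return None
    assert isinstance(value, bytes)
    return value if binary_as_bytes else bytearray(value)

print(type(convert_binary(b"\x01\x02")).__name__)                         # bytes
print(type(convert_binary(b"\x01\x02", binary_as_bytes=False)).__name__)  # bytearray
```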
10 changes: 8 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -854,14 +854,17 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
int_to_decimal_coercion_enabled : bool
If True, applies additional coercions in Python before converting to Arrow
This has performance penalties.
binary_as_bytes : bool
If True, binary type will be deserialized as bytes, otherwise as bytearray.
"""

def __init__(
self,
timezone,
safecheck,
input_types,
int_to_decimal_coercion_enabled=False,
int_to_decimal_coercion_enabled,
Author: the default value for int_to_decimal_coercion_enabled is not used at all.
binary_as_bytes,
):
super().__init__(
timezone=timezone,
@@ -871,6 +874,7 @@ def __init__(
)
self._input_types = input_types
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
self._binary_as_bytes = binary_as_bytes

def load_stream(self, stream):
"""
@@ -887,7 +891,9 @@ def load_stream(self, stream):
List of columns containing list of Python values.
"""
converters = [
ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
)
for dt in self._input_types
]

101 changes: 99 additions & 2 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -19,9 +19,19 @@

from pyspark.errors import AnalysisException, PythonException, PySparkNotImplementedError
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.functions import udf, col
from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
from pyspark.sql.types import DayTimeIntervalType, VarcharType, StructType, StructField, StringType
from pyspark.sql.types import (
ArrayType,
BinaryType,
DayTimeIntervalType,
IntegerType,
MapType,
StringType,
StructField,
StructType,
VarcharType,
)
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
@@ -422,6 +432,93 @@ def tearDownClass(cls):
finally:
super(ArrowPythonUDFNonLegacyTests, cls).tearDownClass()

def test_udf_binary_type(self):
def get_binary_type(x):
return type(x).__name__

binary_udf = udf(get_binary_type, returnType="string", useArrow=True)

df = self.spark.createDataFrame(
[Row(b=b"hello"), Row(b=b"world")], schema=StructType([StructField("b", BinaryType())])
)

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
result = df.select(binary_udf(col("b")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")
self.assertEqual(result[1]["type_name"], "bytes")

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
result = df.select(binary_udf(col("b")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytearray")
self.assertEqual(result[1]["type_name"], "bytearray")

def test_udf_array_binary_type(self):
def check_array_binary_types(arr):
return [type(x).__name__ for x in arr]

array_binary_udf = udf(check_array_binary_types, returnType="array<string>", useArrow=True)

df = self.spark.createDataFrame(
[Row(arr_b=[b"a", b"b"]), Row(arr_b=[b"c", b"d"])],
schema=StructType([StructField("arr_b", ArrayType(BinaryType()))]),
)

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
result = df.select(array_binary_udf(col("arr_b")).alias("types")).collect()
self.assertEqual(result[0]["types"], ["bytes", "bytes"])
self.assertEqual(result[1]["types"], ["bytes", "bytes"])

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
result = df.select(array_binary_udf(col("arr_b")).alias("types")).collect()
self.assertEqual(result[0]["types"], ["bytearray", "bytearray"])
self.assertEqual(result[1]["types"], ["bytearray", "bytearray"])

def test_udf_map_binary_type(self):
def check_map_binary_types(m):
return [type(v).__name__ for v in m.values()]

map_binary_udf = udf(check_map_binary_types, returnType="array<string>", useArrow=True)

df = self.spark.createDataFrame(
[Row(map_b={"k1": b"v1", "k2": b"v2"}), Row(map_b={"k3": b"v3"})],
schema=StructType([StructField("map_b", MapType(StringType(), BinaryType()))]),
)

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
result = df.select(map_binary_udf(col("map_b")).alias("types")).collect()
self.assertEqual(set(result[0]["types"]), {"bytes"})
self.assertEqual(result[1]["types"], ["bytes"])

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
result = df.select(map_binary_udf(col("map_b")).alias("types")).collect()
self.assertEqual(set(result[0]["types"]), {"bytearray"})
self.assertEqual(result[1]["types"], ["bytearray"])

def test_udf_struct_binary_type(self):
def check_struct_binary_type(s):
return type(s.b).__name__

struct_binary_udf = udf(check_struct_binary_type, returnType="string", useArrow=True)

struct_schema = StructType(
[StructField("i", IntegerType()), StructField("b", BinaryType())]
)

df = self.spark.createDataFrame(
[Row(struct_b=Row(i=1, b=b"data1")), Row(struct_b=Row(i=2, b=b"data2"))],
schema=StructType([StructField("struct_b", struct_schema)]),
)

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
result = df.select(struct_binary_udf(col("struct_b")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")
self.assertEqual(result[1]["type_name"], "bytes")

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
result = df.select(struct_binary_udf(col("struct_b")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytearray")
self.assertEqual(result[1]["type_name"], "bytearray")


if __name__ == "__main__":
from pyspark.sql.tests.arrow.test_arrow_python_udf import * # noqa: F401
14 changes: 14 additions & 0 deletions python/pyspark/sql/tests/connect/client/test_client.py
@@ -163,6 +163,20 @@ def Interrupt(self, req: proto.InterruptRequest, metadata):
resp.session_id = self._session_id
return resp

def Config(self, req: proto.ConfigRequest, metadata):
self.req = req
resp = proto.ConfigResponse()
resp.session_id = self._session_id
if req.operation.HasField("get"):
pair = resp.pairs.add()
pair.key = req.operation.get.keys[0]
pair.value = "true" # Default value
elif req.operation.HasField("get_with_default"):
pair = resp.pairs.add()
pair.key = req.operation.get_with_default.pairs[0].key
pair.value = req.operation.get_with_default.pairs[0].value or "true"
return resp

# The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster)
# and it blocks the test process exiting because it is registered as the atexit handler
# in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.