Changes from 22 commits
Commits (33)
ac5d83c
first commit
xianzhe-databricks Sep 26, 2025
1e702c9
data source related
xianzhe-databricks Sep 26, 2025
659fe4e
add conf on connect
xianzhe-databricks Sep 26, 2025
e791532
add partial tests
xianzhe-databricks Sep 26, 2025
6787bfc
Update python/pyspark/sql/tests/arrow/test_arrow_binary_as_bytes_udf.py
HyukjinKwon Sep 28, 2025
453298d
add tests for spark connect
xianzhe-databricks Sep 26, 2025
c859ce7
resolve with remote
xianzhe-databricks Sep 29, 2025
09bfef3
lint
xianzhe-databricks Sep 26, 2025
238e2b7
fix ci
xianzhe-databricks Sep 29, 2025
8db164c
add docs
xianzhe-databricks Sep 29, 2025
b10fc9d
add tests for arrow python udtf
xianzhe-databricks Sep 29, 2025
7d3cf56
add a shield
xianzhe-databricks Sep 29, 2025
922c1a4
rename conf' also make classic dataframe API work with the conf
xianzhe-databricks Sep 30, 2025
16e8428
add size
xianzhe-databricks Sep 30, 2025
fd5fcb2
apply binary as bytes consistently in all cases
xianzhe-databricks Sep 30, 2025
fd687f2
reformat
xianzhe-databricks Sep 30, 2025
ee5d2c2
remove usearrow
xianzhe-databricks Sep 30, 2025
31f52c8
udf test restructure
xianzhe-databricks Oct 1, 2025
d63d7b3
doc
xianzhe-databricks Oct 1, 2025
91aed20
fix ci
xianzhe-databricks Oct 1, 2025
9ce5dfd
fix foreach partition
xianzhe-databricks Oct 1, 2025
e488e3b
add tests for nested structure
xianzhe-databricks Oct 1, 2025
4683449
address comments
xianzhe-databricks Oct 2, 2025
852f9d9
fix build and simplify tests
xianzhe-databricks Oct 2, 2025
b933d61
address comments
xianzhe-databricks Oct 7, 2025
fa202b1
move utils
xianzhe-databricks Oct 7, 2025
9d0c72e
hope to fix mypy
xianzhe-databricks Oct 8, 2025
78f54b9
Fix mypy errors for binary_as_bytes parameter
xianzhe-databricks Oct 15, 2025
61562df
Fix mypy error: add type ignore comment for field converter call
xianzhe-databricks Oct 15, 2025
1bcc68e
doc
xianzhe-databricks Oct 16, 2025
f6d87de
minor change on migration guide
xianzhe-databricks Oct 16, 2025
b8b7cc0
merge with master
xianzhe-databricks Oct 20, 2025
1a72098
nits
xianzhe-databricks Oct 20, 2025
2 changes: 1 addition & 1 deletion docs/sql-ref-datatypes.md
@@ -131,7 +131,7 @@ from pyspark.sql.types import *
|**StringType**|str|StringType()|
|**CharType(length)**|str|CharType(length)|
|**VarcharType(length)**|str|VarcharType(length)|
|**BinaryType**|bytearray|BinaryType()|
|**BinaryType**|bytes|BinaryType()|
|**BooleanType**|bool|BooleanType()|
|**TimestampType**|datetime.datetime|TimestampType()|
|**TimestampNTZType**|datetime.datetime|TimestampNTZType()|
5 changes: 4 additions & 1 deletion python/docs/source/tutorial/sql/type_conversions.rst
@@ -67,6 +67,9 @@ are listed below:
* - spark.sql.execution.pandas.inferPandasDictAsMap
- When enabled, Pandas dictionaries are inferred as MapType. Otherwise, they are inferred as StructType.
- False
* - spark.sql.execution.pyspark.binaryAsBytes
- Introduced in Spark 4.1.0. When enabled, BinaryType is mapped consistently to Python bytes; when disabled, it matches the default PySpark behavior before 4.1.0.
- True

All Conversions
---------------
@@ -105,7 +108,7 @@ All Conversions
- string
- StringType()
* - **BinaryType**
- bytearray
- bytes
- BinaryType()
* - **BooleanType**
- bool
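For readers of this diff, here is a minimal sketch of how the new conf is expected to behave end to end, based on the documentation change above. The printed types are expectations taken from this PR's description, not verified output.

from pyspark.sql import SparkSession
from pyspark.sql.types import BinaryType, StructField, StructType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(b"hello",)], schema=StructType([StructField("b", BinaryType())])
)

# Default in Spark 4.1.0+: BinaryType values come back as Python bytes.
spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "true")
print(type(df.collect()[0]["b"]))  # expected: <class 'bytes'>

# Disabling the conf restores the pre-4.1 behavior (bytearray) on the
# plain collect() path.
spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "false")
print(type(df.collect()[0]["b"]))  # expected: <class 'bytearray'>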
6 changes: 3 additions & 3 deletions python/pyspark/sql/avro/functions.py
@@ -69,7 +69,7 @@ def from_avro(
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> avroDf = df.select(to_avro(df.value).alias("avro"))
>>> avroDf.collect()
[Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))]
[Row(avro=b'\\x00\\x00\\x04\\x00\\nAlice')]

>>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields":
... [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord",
@@ -141,12 +141,12 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
>>> data = ['SPADES']
>>> df = spark.createDataFrame(data, "string")
>>> df.select(to_avro(df.value).alias("suite")).collect()
[Row(suite=bytearray(b'\\x00\\x0cSPADES'))]
[Row(suite=b'\\x00\\x0cSPADES')]

>>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value",
... "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]'''
>>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect()
[Row(suite=bytearray(b'\\x02\\x00'))]
[Row(suite=b'\\x02\\x00')]
"""
from py4j.java_gateway import JVMView
from pyspark.sql.classic.column import _to_java_column
19 changes: 16 additions & 3 deletions python/pyspark/sql/connect/dataframe.py
@@ -1823,7 +1823,14 @@ def collect(self) -> List[Row]:

assert schema is not None and isinstance(schema, StructType)

return ArrowTableToRowsConversion.convert(table, schema)
return ArrowTableToRowsConversion.convert(
table, schema, binary_as_bytes=self._get_binary_as_bytes()
)

def _get_binary_as_bytes(self) -> bool:
"""Get the binary_as_bytes configuration value from Spark session."""
conf_value = self._session.conf.get("spark.sql.execution.pyspark.binaryAsBytes", "true")
return (conf_value or "true").lower() == "true"

def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
query = self._plan.to_proto(self._session.client)
@@ -2075,7 +2082,9 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
table = schema_or_table
if schema is None:
schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
yield from ArrowTableToRowsConversion.convert(table, schema)
yield from ArrowTableToRowsConversion.convert(
table, schema, binary_as_bytes=self._get_binary_as_bytes()
)

def pandas_api(
self, index_col: Optional[Union[str, List[str]]] = None
@@ -2161,8 +2170,12 @@ def foreach_func(row: Any) -> None:

def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
schema = self._schema
binary_as_bytes = self._get_binary_as_bytes()
field_converters = [
ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=False, binary_as_bytes=binary_as_bytes
)
for f in schema.fields
]

def foreach_partition_func(itr: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
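For context on the foreachPartition change above, a rough sketch (not the actual implementation) of how the per-field converters could be applied to each incoming Arrow record batch to rebuild Rows before calling the user function. Only the _create_converter call follows this diff; the batch-iteration details are assumptions.

import pyarrow as pa
from pyspark.sql import Row

def rows_from_batch(batch: pa.RecordBatch, schema, field_converters):
    # Pull each Arrow column out as plain Python values, then rebuild one
    # Row per record by running the matching converter on every field.
    columns = [column.to_pylist() for column in batch.columns]
    for values in zip(*columns):
        converted = {
            name: conv(value)
            for name, conv, value in zip(schema.names, field_converters, values)
        }
        yield Row(**converted)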
33 changes: 24 additions & 9 deletions python/pyspark/sql/conversion.py
@@ -524,7 +524,7 @@ def _create_converter(

@staticmethod
def _create_converter(
dataType: DataType, *, none_on_identity: bool = False
dataType: DataType, *, none_on_identity: bool = False, binary_as_bytes: bool = True
) -> Optional[Callable]:
assert dataType is not None and isinstance(dataType, DataType)

@@ -542,7 +542,9 @@ def _create_converter(
dedup_field_names = _dedup_names(field_names)

field_convs = [
ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
for f in dataType.fields
]

@@ -564,7 +566,7 @@ def convert_struct(value: Any) -> Any:

elif isinstance(dataType, ArrayType):
element_conv = ArrowTableToRowsConversion._create_converter(
dataType.elementType, none_on_identity=True
dataType.elementType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if element_conv is None:
@@ -589,10 +591,10 @@ def convert_array(value: Any) -> Any:

elif isinstance(dataType, MapType):
key_conv = ArrowTableToRowsConversion._create_converter(
dataType.keyType, none_on_identity=True
dataType.keyType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
value_conv = ArrowTableToRowsConversion._create_converter(
dataType.valueType, none_on_identity=True
dataType.valueType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if key_conv is None:
Expand Down Expand Up @@ -646,7 +648,7 @@ def convert_binary(value: Any) -> Any:
return None
else:
assert isinstance(value, bytes)
return bytearray(value)
return value if binary_as_bytes else bytearray(value)

return convert_binary

@@ -676,7 +678,7 @@ def convert_timestample_ntz(value: Any) -> Any:
udt: UserDefinedType = dataType

conv = ArrowTableToRowsConversion._create_converter(
udt.sqlType(), none_on_identity=True
udt.sqlType(), none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if conv is None:
@@ -726,6 +728,13 @@ def convert( # type: ignore[overload-overlap]
) -> List[Row]:
pass

@overload
@staticmethod
def convert(
table: "pa.Table", schema: StructType, *, binary_as_bytes: bool = True
) -> List[Row]:
pass

@overload
@staticmethod
def convert(
@@ -735,7 +744,11 @@ def convert(

@staticmethod # type: ignore[misc]
def convert(
table: "pa.Table", schema: StructType, *, return_as_tuples: bool = False
table: "pa.Table",
schema: StructType,
*,
return_as_tuples: bool = False,
binary_as_bytes: bool = True,
) -> List[Union[Row, tuple]]:
require_minimum_pyarrow_version()
import pyarrow as pa
@@ -748,7 +761,9 @@

if len(fields) > 0:
field_converters = [
ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
for f in schema.fields
]

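To make the behavioral change in convert_binary concrete, the same logic is shown below isolated as a standalone sketch. The practical payoff of returning bytes is that it is immutable and hashable, while bytearray is not.

def convert_binary(value, binary_as_bytes=True):
    # Mirrors the convert_binary hunk above: Arrow already hands back bytes,
    # and the flag decides whether to keep it or downgrade to bytearray.
    if value is None:
        return None
    assert isinstance(value, bytes)
    return value if binary_as_bytes else bytearray(value)

print(convert_binary(b"\x00\x01"))                          # b'\x00\x01'
print(convert_binary(b"\x00\x01", binary_as_bytes=False))   # bytearray(b'\x00\x01')

# bytes can be a dict key or set member; bytearray raises TypeError if hashed.
print(b"\x00\x01" in {b"\x00\x01"})  # True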
10 changes: 8 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -854,14 +854,17 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
int_to_decimal_coercion_enabled : bool
If True, applies additional coercions in Python before converting to Arrow
This has performance penalties.
binary_as_bytes : bool
If True, binary type will be deserialized as bytes, otherwise as bytearray.
"""

def __init__(
self,
timezone,
safecheck,
input_types,
int_to_decimal_coercion_enabled=False,
int_to_decimal_coercion_enabled,
Review comment from the PR author: the default value for int_to_decimal_coercion_enabled is not used at all
binary_as_bytes,
):
super().__init__(
timezone=timezone,
@@ -871,6 +874,7 @@ def __init__(
)
self._input_types = input_types
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
self._binary_as_bytes = binary_as_bytes

def load_stream(self, stream):
"""
@@ -887,7 +891,9 @@ def load_stream(self, stream):
List of columns containing list of Python values.
"""
converters = [
ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
)
for dt in self._input_types
]

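Because int_to_decimal_coercion_enabled and binary_as_bytes are now required parameters (per the author's note above that the old default was never used), every call site has to pass them explicitly. A hypothetical construction with placeholder arguments, not taken from this diff:

from pyspark.sql.pandas.serializers import ArrowBatchUDFSerializer
from pyspark.sql.types import BinaryType

# Placeholder arguments for illustration; real call sites derive these from
# the session configuration and the UDF being executed.
ser = ArrowBatchUDFSerializer(
    timezone="UTC",
    safecheck=True,
    input_types=[BinaryType()],
    int_to_decimal_coercion_enabled=False,
    binary_as_bytes=True,
)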
99 changes: 97 additions & 2 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -19,9 +19,16 @@

from pyspark.errors import AnalysisException, PythonException, PySparkNotImplementedError
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.functions import udf, col
from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
from pyspark.sql.types import DayTimeIntervalType, VarcharType, StructType, StructField, StringType
from pyspark.sql.types import (
ArrayType,
BinaryType,
DayTimeIntervalType,
MapType,
StringType,
StructField,
StructType,
VarcharType,
)
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
@@ -364,6 +371,94 @@ def tearDownClass(cls):
finally:
super().tearDownClass()

def test_udf_binary_type(self):
def get_binary_type(x):
return type(x).__name__

binary_udf = udf(get_binary_type, returnType="string", useArrow=True)

df = self.spark.createDataFrame(
[Row(b=b"hello"), Row(b=b"world")], schema=StructType([StructField("b", BinaryType())])
)
# For Arrow Python UDFs with legacy conversion, BinaryType is always mapped to bytes
with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
result = df.select(binary_udf(col("b")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")
self.assertEqual(result[1]["type_name"], "bytes")

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
result = df.select(binary_udf(col("b")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")
self.assertEqual(result[1]["type_name"], "bytes")

def test_udf_binary_type_in_nested_structures(self):
# For Arrow Python UDFs with legacy conversion, BinaryType is always mapped to bytes
# Test binary in array
def check_array_binary_type(arr):
return type(arr[0]).__name__

array_udf = udf(check_array_binary_type, returnType="string")
df_array = self.spark.createDataFrame(
[Row(arr=[b"hello", b"world"])],
schema=StructType([StructField("arr", ArrayType(BinaryType()))]),
)

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
result = df_array.select(array_udf(col("arr")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
result = df_array.select(array_udf(col("arr")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")

# Test binary in map value
def check_map_binary_type(m):
return type(list(m.values())[0]).__name__

map_udf = udf(check_map_binary_type, returnType="string")
df_map = self.spark.createDataFrame(
[Row(m={"key": b"value"})],
schema=StructType([StructField("m", MapType(StringType(), BinaryType()))]),
)

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
result = df_map.select(map_udf(col("m")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
result = df_map.select(map_udf(col("m")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")

# Test binary in struct
def check_struct_binary_type(s):
return type(s.binary_field).__name__

struct_udf = udf(check_struct_binary_type, returnType="string")
df_struct = self.spark.createDataFrame(
[Row(s=Row(binary_field=b"test", other_field="value"))],
schema=StructType(
[
StructField(
"s",
StructType(
[
StructField("binary_field", BinaryType()),
StructField("other_field", StringType()),
]
),
)
]
),
)

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
result = df_struct.select(struct_udf(col("s")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")

with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
result = df_struct.select(struct_udf(col("s")).alias("type_name")).collect()
self.assertEqual(result[0]["type_name"], "bytes")


class ArrowPythonUDFNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
@classmethod
@@ -15,6 +15,8 @@
# limitations under the License.
#

import unittest

from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
from pyspark.sql.tests.arrow.test_arrow_python_udf import ArrowPythonUDFTestsMixin

@@ -46,6 +48,14 @@ def tearDownClass(cls):
finally:
super().tearDownClass()

@unittest.skip("Duplicate test as it is already tested in ArrowPythonUDFLegacyTests.")
def test_udf_binary_type(self):
super().test_udf_binary_type()

@unittest.skip("Duplicate test as it is already tested in ArrowPythonUDFLegacyTests.")
def test_udf_binary_type_in_nested_structures(self):
super().test_udf_binary_type_in_nested_structures()


class ArrowPythonUDFParityNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
@classmethod
@@ -62,6 +72,14 @@ def tearDownClass(cls):
finally:
super().tearDownClass()

@unittest.skip("Duplicate test as it is already tested in ArrowPythonUDFNonLegacyTests.")
def test_udf_binary_type(self):
super().test_udf_binary_type()

@unittest.skip("Duplicate test as it is already tested in ArrowPythonUDFLegacyTests.")
def test_udf_binary_type_in_nested_structures(self):
super().test_udf_binary_type_in_nested_structures()


class ArrowPythonUDFParityLegacyTests(UDFParityTests, ArrowPythonUDFParityLegacyTestsMixin):
@classmethod