Changes from all commits (26 commits)
ac5d83c  first commit (xianzhe-databricks, Sep 26, 2025)
1e702c9  data source related (xianzhe-databricks, Sep 26, 2025)
659fe4e  add conf on connect (xianzhe-databricks, Sep 26, 2025)
e791532  add partial tests (xianzhe-databricks, Sep 26, 2025)
6787bfc  Update python/pyspark/sql/tests/arrow/test_arrow_binary_as_bytes_udf.py (HyukjinKwon, Sep 28, 2025)
453298d  add tests for spark connect (xianzhe-databricks, Sep 26, 2025)
c859ce7  resolve with remote (xianzhe-databricks, Sep 29, 2025)
09bfef3  lint (xianzhe-databricks, Sep 26, 2025)
238e2b7  fix ci (xianzhe-databricks, Sep 29, 2025)
8db164c  add docs (xianzhe-databricks, Sep 29, 2025)
b10fc9d  add tests for arrow python udtf (xianzhe-databricks, Sep 29, 2025)
7d3cf56  add a shield (xianzhe-databricks, Sep 29, 2025)
922c1a4  rename conf' also make classic dataframe API work with the conf (xianzhe-databricks, Sep 30, 2025)
16e8428  add size (xianzhe-databricks, Sep 30, 2025)
fd5fcb2  apply binary as bytes consistently in all cases (xianzhe-databricks, Sep 30, 2025)
fd687f2  reformat (xianzhe-databricks, Sep 30, 2025)
ee5d2c2  remove usearrow (xianzhe-databricks, Sep 30, 2025)
31f52c8  udf test restructure (xianzhe-databricks, Oct 1, 2025)
d63d7b3  doc (xianzhe-databricks, Oct 1, 2025)
91aed20  fix ci (xianzhe-databricks, Oct 1, 2025)
9ce5dfd  fix foreach partition (xianzhe-databricks, Oct 1, 2025)
e488e3b  add tests for nested structure (xianzhe-databricks, Oct 1, 2025)
4683449  address comments (xianzhe-databricks, Oct 2, 2025)
852f9d9  fix build and simplify tests (xianzhe-databricks, Oct 2, 2025)
b933d61  address comments (xianzhe-databricks, Oct 7, 2025)
fa202b1  move utils (xianzhe-databricks, Oct 7, 2025)
2 changes: 1 addition & 1 deletion docs/sql-ref-datatypes.md
@@ -131,7 +131,7 @@ from pyspark.sql.types import *
|**StringType**|str|StringType()|
|**CharType(length)**|str|CharType(length)|
|**VarcharType(length)**|str|VarcharType(length)|
|**BinaryType**|bytearray|BinaryType()|
|**BinaryType**|bytes|BinaryType()|
|**BooleanType**|bool|BooleanType()|
|**TimestampType**|datetime.datetime|TimestampType()|
|**TimestampNTZType**|datetime.datetime|TimestampNTZType()|
5 changes: 4 additions & 1 deletion python/docs/source/tutorial/sql/type_conversions.rst
@@ -67,6 +67,9 @@ are listed below:
* - spark.sql.execution.pandas.inferPandasDictAsMap
- When enabled, Pandas dictionaries are inferred as MapType. Otherwise, they are inferred as StructType.
- False
* - spark.sql.execution.pyspark.binaryAsBytes
- Introduced in Spark 4.1.0. When enabled, BinaryType is mapped consistently to Python bytes; when disabled, matches the PySpark default behavior before 4.1.0.
Contributor comment: maybe we can explain briefly what the default behavior before 4.1.0 is?

Author reply: I ended up explaining more in the definition of spark.sql.execution.pyspark.binaryAsBytes in the SQL conf, as it is really long...
- True

All Conversions
---------------
@@ -105,7 +108,7 @@ All Conversions
- string
- StringType()
* - **BinaryType**
- bytearray
- bytes
- BinaryType()
* - **BooleanType**
- bool
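For context, a minimal sketch of what the new flag changes at collect() time. This assumes a Spark build that already includes this PR (4.1.0+), where spark.sql.execution.pyspark.binaryAsBytes is available; the session and data setup below are illustrative, not part of the change.

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, BinaryType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(b"hello",)], schema=StructType([StructField("b", BinaryType())])
)

# Default in this PR: BinaryType values come back as Python bytes.
spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "true")
print(type(df.collect()[0].b))  # <class 'bytes'>

# Legacy behavior (pre-4.1.0 DataFrame API): bytearray.
spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "false")
print(type(df.collect()[0].b))  # <class 'bytearray'>
```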
6 changes: 3 additions & 3 deletions python/pyspark/sql/avro/functions.py
@@ -69,7 +69,7 @@ def from_avro(
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> avroDf = df.select(to_avro(df.value).alias("avro"))
>>> avroDf.collect()
[Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))]
[Row(avro=b'\\x00\\x00\\x04\\x00\\nAlice')]
>>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields":
... [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord",
@@ -141,12 +141,12 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
>>> data = ['SPADES']
>>> df = spark.createDataFrame(data, "string")
>>> df.select(to_avro(df.value).alias("suite")).collect()
[Row(suite=bytearray(b'\\x00\\x0cSPADES'))]
[Row(suite=b'\\x00\\x0cSPADES')]
>>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value",
... "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]'''
>>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect()
[Row(suite=bytearray(b'\\x02\\x00'))]
[Row(suite=b'\\x02\\x00')]
"""
from py4j.java_gateway import JVMView
from pyspark.sql.classic.column import _to_java_column
19 changes: 16 additions & 3 deletions python/pyspark/sql/connect/dataframe.py
@@ -1823,7 +1823,14 @@ def collect(self) -> List[Row]:

assert schema is not None and isinstance(schema, StructType)

return ArrowTableToRowsConversion.convert(table, schema)
return ArrowTableToRowsConversion.convert(
table, schema, binary_as_bytes=self._get_binary_as_bytes()
)

def _get_binary_as_bytes(self) -> bool:
"""Get the binary_as_bytes configuration value from Spark session."""
conf_value = self._session.conf.get("spark.sql.execution.pyspark.binaryAsBytes", "true")
return conf_value is not None and conf_value.lower() == "true"

def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
query = self._plan.to_proto(self._session.client)
@@ -2075,7 +2082,9 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
table = schema_or_table
if schema is None:
schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
yield from ArrowTableToRowsConversion.convert(table, schema)
yield from ArrowTableToRowsConversion.convert(
table, schema, binary_as_bytes=self._get_binary_as_bytes()
)

def pandas_api(
self, index_col: Optional[Union[str, List[str]]] = None
@@ -2161,8 +2170,12 @@ def foreach_func(row: Any) -> None:

def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
schema = self._schema
binary_as_bytes = self._get_binary_as_bytes()
field_converters = [
ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=False, binary_as_bytes=binary_as_bytes
)
for f in schema.fields
]

def foreach_partition_func(itr: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
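For reference, a standalone sketch of the conf lookup that the new _get_binary_as_bytes helper performs. The conf argument here is a hypothetical stand-in for self._session.conf, which returns string values on Spark Connect, so the helper normalizes them to a bool and treats a missing value as True.

```python
def get_binary_as_bytes(conf) -> bool:
    # Read the string value with "true" as the default, then normalize to a bool;
    # anything other than a case-insensitive "true" (including None) is False.
    value = conf.get("spark.sql.execution.pyspark.binaryAsBytes", "true")
    return value is not None and value.lower() == "true"

print(get_binary_as_bytes({}))                                                       # True (default)
print(get_binary_as_bytes({"spark.sql.execution.pyspark.binaryAsBytes": "false"}))   # False
```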
33 changes: 24 additions & 9 deletions python/pyspark/sql/conversion.py
@@ -524,7 +524,7 @@ def _create_converter(

@staticmethod
def _create_converter(
dataType: DataType, *, none_on_identity: bool = False
dataType: DataType, *, none_on_identity: bool = False, binary_as_bytes: bool = True
) -> Optional[Callable]:
assert dataType is not None and isinstance(dataType, DataType)

@@ -542,7 +542,9 @@ def _create_converter(
dedup_field_names = _dedup_names(field_names)

field_convs = [
ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
for f in dataType.fields
]

@@ -564,7 +566,7 @@ def convert_struct(value: Any) -> Any:

elif isinstance(dataType, ArrayType):
element_conv = ArrowTableToRowsConversion._create_converter(
dataType.elementType, none_on_identity=True
dataType.elementType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if element_conv is None:
@@ -589,10 +591,10 @@ def convert_array(value: Any) -> Any:

elif isinstance(dataType, MapType):
key_conv = ArrowTableToRowsConversion._create_converter(
dataType.keyType, none_on_identity=True
dataType.keyType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
value_conv = ArrowTableToRowsConversion._create_converter(
dataType.valueType, none_on_identity=True
dataType.valueType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if key_conv is None:
@@ -646,7 +648,7 @@ def convert_binary(value: Any) -> Any:
return None
else:
assert isinstance(value, bytes)
return bytearray(value)
return value if binary_as_bytes else bytearray(value)

return convert_binary

@@ -676,7 +678,7 @@ def convert_timestample_ntz(value: Any) -> Any:
udt: UserDefinedType = dataType

conv = ArrowTableToRowsConversion._create_converter(
udt.sqlType(), none_on_identity=True
udt.sqlType(), none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if conv is None:
@@ -726,6 +728,13 @@ def convert(  # type: ignore[overload-overlap]
) -> List[Row]:
pass

@overload
@staticmethod
def convert(
table: "pa.Table", schema: StructType, *, binary_as_bytes: bool = True
) -> List[Row]:
pass

@overload
@staticmethod
def convert(
@@ -735,7 +744,11 @@ def convert(

@staticmethod # type: ignore[misc]
def convert(
table: "pa.Table", schema: StructType, *, return_as_tuples: bool = False
table: "pa.Table",
schema: StructType,
*,
return_as_tuples: bool = False,
binary_as_bytes: bool = True,
Contributor comment: maybe this should be controlled by a flag?

Author reply: it is already controlled by a flag; binary_as_bytes is passed on the caller's side as the value of the SQL conf spark.sql.execution.pyspark.binaryAsBytes. It is not possible, or against the style, to access the SQL conf in conversion.py.
) -> List[Union[Row, tuple]]:
require_minimum_pyarrow_version()
import pyarrow as pa
@@ -748,7 +761,9 @@ def convert(

if len(fields) > 0:
field_converters = [
ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
for f in schema.fields
]

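A small usage sketch against the internal helper shown above, assuming this PR is applied. ArrowTableToRowsConversion._create_converter is an internal API, so this is only an illustration of the binary_as_bytes switch, not a supported interface.

```python
from pyspark.sql.conversion import ArrowTableToRowsConversion
from pyspark.sql.types import BinaryType

# Build the per-field converter both ways and compare the Python types produced.
to_bytes = ArrowTableToRowsConversion._create_converter(BinaryType(), binary_as_bytes=True)
to_bytearray = ArrowTableToRowsConversion._create_converter(BinaryType(), binary_as_bytes=False)

print(type(to_bytes(b"abc")))      # <class 'bytes'>
print(type(to_bytearray(b"abc")))  # <class 'bytearray'>
```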
8 changes: 4 additions & 4 deletions python/pyspark/sql/functions/builtin.py
@@ -16627,14 +16627,14 @@ def to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> C
>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([("abc",)], ["e"])
>>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect()
[Row(r=bytearray(b'abc'))]
[Row(r=b'abc')]

Example 2: Convert string to a timestamp without encoding specified

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([("414243",)], ["e"])
>>> df.select(sf.try_to_binary(df.e).alias('r')).collect()
[Row(r=bytearray(b'ABC'))]
[Row(r=b'ABC')]
"""
if format is not None:
return _invoke_function_over_columns("to_binary", col, format)
@@ -17650,14 +17650,14 @@ def try_to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None)
>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([("abc",)], ["e"])
>>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect()
[Row(r=bytearray(b'abc'))]
[Row(r=b'abc')]

Example 2: Convert string to a timestamp without encoding specified

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([("414243",)], ["e"])
>>> df.select(sf.try_to_binary(df.e).alias('r')).collect()
[Row(r=bytearray(b'ABC'))]
[Row(r=b'ABC')]

Example 3: Converion failure results in NULL when ANSI mode is on

10 changes: 8 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -854,14 +854,17 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
int_to_decimal_coercion_enabled : bool
If True, applies additional coercions in Python before converting to Arrow
This has performance penalties.
binary_as_bytes : bool
If True, binary type will be deserialized as bytes, otherwise as bytearray.
"""

def __init__(
self,
timezone,
safecheck,
input_types,
int_to_decimal_coercion_enabled=False,
int_to_decimal_coercion_enabled,
Author comment: the default value for int_to_decimal_coercion_enabled is not used at all.
binary_as_bytes,
):
super().__init__(
timezone=timezone,
Expand All @@ -871,6 +874,7 @@ def __init__(
)
self._input_types = input_types
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
self._binary_as_bytes = binary_as_bytes

def load_stream(self, stream):
"""
@@ -887,7 +891,9 @@ def load_stream(self, stream):
List of columns containing list of Python values.
"""
converters = [
ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
)
for dt in self._input_types
]

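A minimal pyarrow-level illustration of the two deserialization targets that binary_as_bytes selects between; this snippet stands alone and does not use the serializer itself.

```python
import pyarrow as pa

arr = pa.array([b"\x00\x01", None], type=pa.binary())

as_bytes = arr.to_pylist()  # pyarrow yields bytes by default: [b'\x00\x01', None]
as_bytearray = [bytearray(v) if v is not None else None for v in as_bytes]

print(type(as_bytes[0]), type(as_bytearray[0]))  # <class 'bytes'> <class 'bytearray'>
```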
89 changes: 87 additions & 2 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -19,16 +19,26 @@

from pyspark.errors import AnalysisException, PythonException, PySparkNotImplementedError
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.functions import udf, col
from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
from pyspark.sql.types import DayTimeIntervalType, VarcharType, StructType, StructField, StringType
from pyspark.sql.types import (
ArrayType,
BinaryType,
DayTimeIntervalType,
MapType,
StringType,
StructField,
StructType,
VarcharType,
)
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
ReusedSQLTestCase,
)
from pyspark.testing.utils import assertDataFrameEqual
from pyspark.util import PythonEvalType


@@ -364,6 +374,81 @@ def tearDownClass(cls):
finally:
super().tearDownClass()

def test_udf_binary_type(self):
def get_binary_type(x):
return type(x).__name__

binary_udf = udf(get_binary_type, returnType="string", useArrow=True)

df = self.spark.createDataFrame(
[Row(b=b"hello"), Row(b=b"world")], schema=StructType([StructField("b", BinaryType())])
)
expected = self.spark.createDataFrame([Row(type_name="bytes"), Row(type_name="bytes")])
# For Arrow Python UDF with legacy conversion BinaryType is always mapped to bytes
for conf_val in ["true", "false"]:
with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_val}):
result = df.select(binary_udf(col("b")).alias("type_name"))
assertDataFrameEqual(result, expected)

def test_udf_binary_type_in_nested_structures(self):
# For Arrow Python UDF with legacy conversion BinaryType is always mapped to bytes
# Test binary in array
def check_array_binary_type(arr):
return type(arr[0]).__name__

array_udf = udf(check_array_binary_type, returnType="string")
df_array = self.spark.createDataFrame(
[Row(arr=[b"hello", b"world"])],
schema=StructType([StructField("arr", ArrayType(BinaryType()))]),
)
expected = self.spark.createDataFrame([Row(type_name="bytes")])
for conf_val in ["true", "false"]:
with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_val}):
result = df_array.select(array_udf(col("arr")).alias("type_name"))
assertDataFrameEqual(result, expected)

# Test binary in map value
def check_map_binary_type(m):
return type(list(m.values())[0]).__name__

map_udf = udf(check_map_binary_type, returnType="string")
df_map = self.spark.createDataFrame(
[Row(m={"key": b"value"})],
schema=StructType([StructField("m", MapType(StringType(), BinaryType()))]),
)
expected = self.spark.createDataFrame([Row(type_name="bytes")])
for conf_val in ["true", "false"]:
with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_val}):
result = df_map.select(map_udf(col("m")).alias("type_name"))
assertDataFrameEqual(result, expected)

# Test binary in struct
def check_struct_binary_type(s):
return type(s.binary_field).__name__

struct_udf = udf(check_struct_binary_type, returnType="string")
df_struct = self.spark.createDataFrame(
[Row(s=Row(binary_field=b"test", other_field="value"))],
schema=StructType(
[
StructField(
"s",
StructType(
[
StructField("binary_field", BinaryType()),
StructField("other_field", StringType()),
]
),
)
]
),
)
expected = self.spark.createDataFrame([Row(type_name="bytes")])
for conf_val in ["true", "false"]:
with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_val}):
result = df_struct.select(struct_udf(col("s")).alias("type_name"))
assertDataFrameEqual(result, expected)


class ArrowPythonUDFNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
@classmethod