diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index 25f847508ce01..fd3e0b7a3164f 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -131,7 +131,7 @@ from pyspark.sql.types import * |**StringType**|str|StringType()| |**CharType(length)**|str|CharType(length)| |**VarcharType(length)**|str|VarcharType(length)| -|**BinaryType**|bytearray|BinaryType()| +|**BinaryType**|bytes|BinaryType()| |**BooleanType**|bool|BooleanType()| |**TimestampType**|datetime.datetime|TimestampType()| |**TimestampNTZType**|datetime.datetime|TimestampNTZType()| diff --git a/python/docs/source/tutorial/sql/type_conversions.rst b/python/docs/source/tutorial/sql/type_conversions.rst index 2f13701995ef2..82f6a428bb528 100644 --- a/python/docs/source/tutorial/sql/type_conversions.rst +++ b/python/docs/source/tutorial/sql/type_conversions.rst @@ -67,6 +67,9 @@ are listed below: * - spark.sql.execution.pandas.inferPandasDictAsMap - When enabled, Pandas dictionaries are inferred as MapType. Otherwise, they are inferred as StructType. - False + * - spark.sql.execution.pyspark.binaryAsBytes + - Introduced in Spark 4.1.0. When enabled, BinaryType is mapped consistently to Python bytes; when disabled, matches the PySpark default behavior before 4.1.0. + - True All Conversions --------------- @@ -105,7 +108,7 @@ All Conversions - string - StringType() * - **BinaryType** - - bytearray + - bytes - BinaryType() * - **BooleanType** - bool diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py index 5c4be6570cd6a..e1b70c5fd3597 100644 --- a/python/pyspark/sql/avro/functions.py +++ b/python/pyspark/sql/avro/functions.py @@ -69,7 +69,7 @@ def from_avro( >>> df = spark.createDataFrame(data, ("key", "value")) >>> avroDf = df.select(to_avro(df.value).alias("avro")) >>> avroDf.collect() - [Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))] + [Row(avro=b'\\x00\\x00\\x04\\x00\\nAlice')] >>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields": ... [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord", @@ -141,12 +141,12 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column: >>> data = ['SPADES'] >>> df = spark.createDataFrame(data, "string") >>> df.select(to_avro(df.value).alias("suite")).collect() - [Row(suite=bytearray(b'\\x00\\x0cSPADES'))] + [Row(suite=b'\\x00\\x0cSPADES')] >>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value", ... 
"symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]''' >>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect() - [Row(suite=bytearray(b'\\x02\\x00'))] + [Row(suite=b'\\x02\\x00')] """ from py4j.java_gateway import JVMView from pyspark.sql.classic.column import _to_java_column diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index ab7fdc90ba3c5..b3f5947e491b6 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1823,7 +1823,14 @@ def collect(self) -> List[Row]: assert schema is not None and isinstance(schema, StructType) - return ArrowTableToRowsConversion.convert(table, schema) + return ArrowTableToRowsConversion.convert( + table, schema, binary_as_bytes=self._get_binary_as_bytes() + ) + + def _get_binary_as_bytes(self) -> bool: + """Get the binary_as_bytes configuration value from Spark session.""" + conf_value = self._session.conf.get("spark.sql.execution.pyspark.binaryAsBytes", "true") + return conf_value is not None and conf_value.lower() == "true" def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: query = self._plan.to_proto(self._session.client) @@ -2075,7 +2082,9 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: table = schema_or_table if schema is None: schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=True) - yield from ArrowTableToRowsConversion.convert(table, schema) + yield from ArrowTableToRowsConversion.convert( + table, schema, binary_as_bytes=self._get_binary_as_bytes() + ) def pandas_api( self, index_col: Optional[Union[str, List[str]]] = None @@ -2161,8 +2170,12 @@ def foreach_func(row: Any) -> None: def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: schema = self._schema + binary_as_bytes = self._get_binary_as_bytes() field_converters = [ - ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields + ArrowTableToRowsConversion._create_converter( + f.dataType, none_on_identity=False, binary_as_bytes=binary_as_bytes + ) + for f in schema.fields ] def foreach_partition_func(itr: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]: diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index f1aa55c2039ac..694bd4186751f 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -518,13 +518,13 @@ def _create_converter(dataType: DataType) -> Callable: @overload @staticmethod def _create_converter( - dataType: DataType, *, none_on_identity: bool = True + dataType: DataType, *, none_on_identity: bool = True, binary_as_bytes: bool = True ) -> Optional[Callable]: pass @staticmethod def _create_converter( - dataType: DataType, *, none_on_identity: bool = False + dataType: DataType, *, none_on_identity: bool = False, binary_as_bytes: bool = True ) -> Optional[Callable]: assert dataType is not None and isinstance(dataType, DataType) @@ -542,7 +542,9 @@ def _create_converter( dedup_field_names = _dedup_names(field_names) field_convs = [ - ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True) + ArrowTableToRowsConversion._create_converter( + f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes + ) for f in dataType.fields ] @@ -553,7 +555,7 @@ def convert_struct(value: Any) -> Any: assert isinstance(value, dict) _values = [ - field_convs[i](value.get(name, None)) # type: ignore[misc] + field_convs[i](value.get(name, None)) if field_convs[i] is not None else value.get(name, 
None) for i, name in enumerate(dedup_field_names) @@ -564,7 +566,7 @@ def convert_struct(value: Any) -> Any: elif isinstance(dataType, ArrayType): element_conv = ArrowTableToRowsConversion._create_converter( - dataType.elementType, none_on_identity=True + dataType.elementType, none_on_identity=True, binary_as_bytes=binary_as_bytes ) if element_conv is None: @@ -589,10 +591,10 @@ def convert_array(value: Any) -> Any: elif isinstance(dataType, MapType): key_conv = ArrowTableToRowsConversion._create_converter( - dataType.keyType, none_on_identity=True + dataType.keyType, none_on_identity=True, binary_as_bytes=binary_as_bytes ) value_conv = ArrowTableToRowsConversion._create_converter( - dataType.valueType, none_on_identity=True + dataType.valueType, none_on_identity=True, binary_as_bytes=binary_as_bytes ) if key_conv is None: @@ -646,7 +648,7 @@ def convert_binary(value: Any) -> Any: return None else: assert isinstance(value, bytes) - return bytearray(value) + return value if binary_as_bytes else bytearray(value) return convert_binary @@ -676,7 +678,7 @@ def convert_timestample_ntz(value: Any) -> Any: udt: UserDefinedType = dataType conv = ArrowTableToRowsConversion._create_converter( - udt.sqlType(), none_on_identity=True + udt.sqlType(), none_on_identity=True, binary_as_bytes=binary_as_bytes ) if conv is None: @@ -721,21 +723,26 @@ def convert_variant(value: Any) -> Any: @overload @staticmethod - def convert( # type: ignore[overload-overlap] - table: "pa.Table", schema: StructType - ) -> List[Row]: + def convert(table: "pa.Table", schema: StructType) -> List[Row]: pass @overload @staticmethod - def convert( - table: "pa.Table", schema: StructType, *, return_as_tuples: bool = True - ) -> List[tuple]: + def convert(table: "pa.Table", schema: StructType, *, binary_as_bytes: bool) -> List[Row]: + pass + + @overload + @staticmethod + def convert(table: "pa.Table", schema: StructType, *, return_as_tuples: bool) -> List[tuple]: pass @staticmethod # type: ignore[misc] def convert( - table: "pa.Table", schema: StructType, *, return_as_tuples: bool = False + table: "pa.Table", + schema: StructType, + *, + return_as_tuples: bool = False, + binary_as_bytes: bool = True, ) -> List[Union[Row, tuple]]: require_minimum_pyarrow_version() import pyarrow as pa @@ -748,7 +755,9 @@ def convert( if len(fields) > 0: field_converters = [ - ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True) + ArrowTableToRowsConversion._create_converter( + f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes + ) for f in schema.fields ] diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index fb5d8ea461963..512af852babcd 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16627,14 +16627,14 @@ def to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> C >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("abc",)], ["e"]) >>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect() - [Row(r=bytearray(b'abc'))] + [Row(r=b'abc')] Example 2: Convert string to a timestamp without encoding specified >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("414243",)], ["e"]) >>> df.select(sf.try_to_binary(df.e).alias('r')).collect() - [Row(r=bytearray(b'ABC'))] + [Row(r=b'ABC')] """ if format is not None: return _invoke_function_over_columns("to_binary", col, format) @@ -17650,14 +17650,14 @@ def try_to_binary(col: "ColumnOrName", 
format: Optional["ColumnOrName"] = None) >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("abc",)], ["e"]) >>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect() - [Row(r=bytearray(b'abc'))] + [Row(r=b'abc')] Example 2: Convert string to a timestamp without encoding specified >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("414243",)], ["e"]) >>> df.select(sf.try_to_binary(df.e).alias('r')).collect() - [Row(r=bytearray(b'ABC'))] + [Row(r=b'ABC')] Example 3: Converion failure results in NULL when ANSI mode is on diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 801a87c06cc5d..9030cd89776d7 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -854,6 +854,8 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer): int_to_decimal_coercion_enabled : bool If True, applies additional coercions in Python before converting to Arrow This has performance penalties. + binary_as_bytes : bool + If True, binary type will be deserialized as bytes, otherwise as bytearray. """ def __init__( @@ -861,7 +863,8 @@ def __init__( timezone, safecheck, input_types, - int_to_decimal_coercion_enabled=False, + int_to_decimal_coercion_enabled, + binary_as_bytes, ): super().__init__( timezone=timezone, @@ -871,6 +874,7 @@ def __init__( ) self._input_types = input_types self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled + self._binary_as_bytes = binary_as_bytes def load_stream(self, stream): """ @@ -887,7 +891,9 @@ def load_stream(self, stream): List of columns containing list of Python values. """ converters = [ - ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True) + ArrowTableToRowsConversion._create_converter( + dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes + ) for dt in self._input_types ] diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py index a054261304c6f..5ea329ad5444d 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py @@ -19,9 +19,18 @@ from pyspark.errors import AnalysisException, PythonException, PySparkNotImplementedError from pyspark.sql import Row -from pyspark.sql.functions import udf +from pyspark.sql.functions import udf, col from pyspark.sql.tests.test_udf import BaseUDFTestsMixin -from pyspark.sql.types import DayTimeIntervalType, VarcharType, StructType, StructField, StringType +from pyspark.sql.types import ( + ArrayType, + BinaryType, + DayTimeIntervalType, + MapType, + StringType, + StructField, + StructType, + VarcharType, +) from pyspark.testing.sqlutils import ( have_pandas, have_pyarrow, @@ -29,6 +38,7 @@ pyarrow_requirement_message, ReusedSQLTestCase, ) +from pyspark.testing.utils import assertDataFrameEqual from pyspark.util import PythonEvalType @@ -364,6 +374,81 @@ def tearDownClass(cls): finally: super().tearDownClass() + def test_udf_binary_type(self): + def get_binary_type(x): + return type(x).__name__ + + binary_udf = udf(get_binary_type, returnType="string", useArrow=True) + + df = self.spark.createDataFrame( + [Row(b=b"hello"), Row(b=b"world")], schema=StructType([StructField("b", BinaryType())]) + ) + expected = self.spark.createDataFrame([Row(type_name="bytes"), Row(type_name="bytes")]) + # For Arrow Python UDF with legacy conversion BinaryType is always mapped to bytes + for conf_val in 
["true", "false"]: + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_val}): + result = df.select(binary_udf(col("b")).alias("type_name")) + assertDataFrameEqual(result, expected) + + def test_udf_binary_type_in_nested_structures(self): + # For Arrow Python UDF with legacy conversion BinaryType is always mapped to bytes + # Test binary in array + def check_array_binary_type(arr): + return type(arr[0]).__name__ + + array_udf = udf(check_array_binary_type, returnType="string") + df_array = self.spark.createDataFrame( + [Row(arr=[b"hello", b"world"])], + schema=StructType([StructField("arr", ArrayType(BinaryType()))]), + ) + expected = self.spark.createDataFrame([Row(type_name="bytes")]) + for conf_val in ["true", "false"]: + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_val}): + result = df_array.select(array_udf(col("arr")).alias("type_name")) + assertDataFrameEqual(result, expected) + + # Test binary in map value + def check_map_binary_type(m): + return type(list(m.values())[0]).__name__ + + map_udf = udf(check_map_binary_type, returnType="string") + df_map = self.spark.createDataFrame( + [Row(m={"key": b"value"})], + schema=StructType([StructField("m", MapType(StringType(), BinaryType()))]), + ) + expected = self.spark.createDataFrame([Row(type_name="bytes")]) + for conf_val in ["true", "false"]: + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_val}): + result = df_map.select(map_udf(col("m")).alias("type_name")) + assertDataFrameEqual(result, expected) + + # Test binary in struct + def check_struct_binary_type(s): + return type(s.binary_field).__name__ + + struct_udf = udf(check_struct_binary_type, returnType="string") + df_struct = self.spark.createDataFrame( + [Row(s=Row(binary_field=b"test", other_field="value"))], + schema=StructType( + [ + StructField( + "s", + StructType( + [ + StructField("binary_field", BinaryType()), + StructField("other_field", StringType()), + ] + ), + ) + ] + ), + ) + expected = self.spark.createDataFrame([Row(type_name="bytes")]) + for conf_val in ["true", "false"]: + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_val}): + result = df_struct.select(struct_udf(col("s")).alias("type_name")) + assertDataFrameEqual(result, expected) + class ArrowPythonUDFNonLegacyTestsMixin(ArrowPythonUDFTestsMixin): @classmethod diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py index 6af55bb0b0bd1..0f879087e7bb2 100644 --- a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py +++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py @@ -15,6 +15,8 @@ # limitations under the License. 
# +import unittest + from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests from pyspark.sql.tests.arrow.test_arrow_python_udf import ArrowPythonUDFTestsMixin @@ -46,6 +48,14 @@ def tearDownClass(cls): finally: super().tearDownClass() + @unittest.skip("Duplicate test as it is already tested in ArrowPythonUDFLegacyTests.") + def test_udf_binary_type(self): + super().test_udf_binary_type() + + @unittest.skip("Duplicate test as it is already tested in ArrowPythonUDFLegacyTests.") + def test_udf_binary_type_in_nested_structures(self): + super().test_udf_binary_type_in_nested_structures() + class ArrowPythonUDFParityNonLegacyTestsMixin(ArrowPythonUDFTestsMixin): @classmethod @@ -62,6 +72,14 @@ def tearDownClass(cls): finally: super().tearDownClass() + @unittest.skip("Duplicate test as it is already tested in ArrowPythonUDFNonLegacyTests.") + def test_udf_binary_type(self): + super().test_udf_binary_type() + + @unittest.skip("Duplicate test as it is already tested in ArrowPythonUDFNonLegacyTests.") + def test_udf_binary_type_in_nested_structures(self): + super().test_udf_binary_type_in_nested_structures() + class ArrowPythonUDFParityLegacyTests(UDFParityTests, ArrowPythonUDFParityLegacyTestsMixin): @classmethod diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index b0a7f3f381274..c189f996cbe43 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -163,6 +163,20 @@ def Interrupt(self, req: proto.InterruptRequest, metadata): resp.session_id = self._session_id return resp + def Config(self, req: proto.ConfigRequest, metadata): + self.req = req + resp = proto.ConfigResponse() + resp.session_id = self._session_id + if req.operation.HasField("get"): + pair = resp.pairs.add() + pair.key = req.operation.get.keys[0] + pair.value = "true" # Default value + elif req.operation.HasField("get_with_default"): + pair = resp.pairs.add() + pair.key = req.operation.get_with_default.pairs[0].key + pair.value = req.operation.get_with_default.pairs[0].value or "true" + return resp + # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster) # and it blocks the test process exiting because it is registered as the atexit handler # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.
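The Spark Connect collection tests that follow exercise the user-facing contract of the new flag. As a rough illustration only (not part of this patch), assuming an active SparkSession bound to the name `spark`, the behavior under test is:

# Illustrative sketch, not part of this patch: effect of spark.sql.execution.pyspark.binaryAsBytes
# on collected BinaryType values, assuming an active SparkSession bound to the name `spark`.
df = spark.sql("SELECT CAST('abc' AS BINARY) AS b")

spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "true")
assert isinstance(df.collect()[0].b, bytes)  # default in 4.1.0: BinaryType -> bytes

spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes", "false")
assert isinstance(df.collect()[0].b, bytearray)  # pre-4.1.0 behavior: BinaryType -> bytearray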
diff --git a/python/pyspark/sql/tests/connect/test_connect_collection.py b/python/pyspark/sql/tests/connect/test_connect_collection.py index 61932c38733b7..2343e8b9cde4c 100644 --- a/python/pyspark/sql/tests/connect/test_connect_collection.py +++ b/python/pyspark/sql/tests/connect/test_connect_collection.py @@ -291,6 +291,85 @@ def test_collect_nested_type(self): ).collect(), ) + def test_collect_binary_type(self): + """Test that df.collect() respects binary_as_bytes configuration for server-side data""" + query = """ + SELECT * FROM VALUES + (CAST('hello' AS BINARY)), + (CAST('world' AS BINARY)) + AS tab(b) + """ + + for conf_value in ["true", "false"]: + expected_type = bytes if conf_value == "true" else bytearray + with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + connect_rows = self.connect.sql(query).collect() + self.assertEqual(len(connect_rows), 2) + for row in connect_rows: + self.assertIsInstance(row.b, expected_type) + + spark_rows = self.spark.sql(query).collect() + self.assertEqual(len(spark_rows), 2) + for row in spark_rows: + self.assertIsInstance(row.b, expected_type) + + def test_to_local_iterator_binary_type(self): + """Test that df.toLocalIterator() respects binary_as_bytes configuration""" + query = """ + SELECT * FROM VALUES + (CAST('data1' AS BINARY)), + (CAST('data2' AS BINARY)) + AS tab(b) + """ + + for conf_value in ["true", "false"]: + expected_type = bytes if conf_value == "true" else bytearray + with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + connect_count = 0 + for row in self.connect.sql(query).toLocalIterator(): + self.assertIsInstance(row.b, expected_type) + connect_count += 1 + self.assertEqual(connect_count, 2) + + spark_count = 0 + for row in self.spark.sql(query).toLocalIterator(): + self.assertIsInstance(row.b, expected_type) + spark_count += 1 + self.assertEqual(spark_count, 2) + + def test_foreach_partition_binary_type(self): + """Test that df.foreachPartition() respects binary_as_bytes configuration + + Since foreachPartition() runs on executors and cannot return data to the driver, + we test by ensuring the function doesn't throw exceptions when it expects the correct types. 
+ """ + query = """ + SELECT * FROM VALUES + (CAST('partition1' AS BINARY)), + (CAST('partition2' AS BINARY)) + AS tab(b) + """ + + for conf_value in ["true", "false"]: + expected_type = bytes if conf_value == "true" else bytearray + expected_type_name = "bytes" if conf_value == "true" else "bytearray" + + with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + + def assert_type(iterator): + count = 0 + for row in iterator: + # This will raise an exception if the type is not as expected + assert isinstance( + row.b, expected_type + ), f"Expected {expected_type_name}, got {type(row.b).__name__}" + count += 1 + # Ensure we actually processed rows + assert count > 0, "No rows were processed" + + self.connect.sql(query).foreachPartition(assert_type) + self.spark.sql(query).foreachPartition(assert_type) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_collection import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py index 2b18fe8d04d7a..ca3b6f6671aa5 100644 --- a/python/pyspark/sql/tests/test_conversion.py +++ b/python/pyspark/sql/tests/test_conversion.py @@ -109,6 +109,43 @@ def test_conversion(self): with self.subTest(expected=e): self.assertEqual(a, e) + def test_binary_as_bytes_conversion(self): + data = [ + ( + str(i).encode(), # simple binary + [str(j).encode() for j in range(3)], # array of binary + {str(j): str(j).encode() for j in range(2)}, # map with binary values + {"b": str(i).encode()}, # struct with binary + ) + for i in range(2) + ] + schema = ( + StructType() + .add("b", BinaryType()) + .add("arr_b", ArrayType(BinaryType())) + .add("map_b", MapType(StringType(), BinaryType())) + .add("struct_b", StructType().add("b", BinaryType())) + ) + + tbl = LocalDataToArrowConversion.convert(data, schema, use_large_var_types=False) + + for binary_as_bytes, expected_type in [(True, bytes), (False, bytearray)]: + actual = ArrowTableToRowsConversion.convert( + tbl, schema, binary_as_bytes=binary_as_bytes + ) + + for row in actual: + # Simple binary field + self.assertIsInstance(row.b, expected_type) + # Array elements + for elem in row.arr_b: + self.assertIsInstance(elem, expected_type) + # Map values + for value in row.map_b.values(): + self.assertIsInstance(value, expected_type) + # Struct field + self.assertIsInstance(row.struct_b.b, expected_type) + if __name__ == "__main__": from pyspark.sql.tests.test_conversion import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 1d1cc3507f0e6..33ba2c02c63cb 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -20,9 +20,17 @@ import tempfile from pyspark.errors import AnalysisException +from pyspark.sql import Row from pyspark.sql.functions import col, lit from pyspark.sql.readwriter import DataFrameWriterV2 -from pyspark.sql.types import StructType, StructField, StringType +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + BinaryType, + ArrayType, + MapType, +) from pyspark.testing import assertDataFrameEqual from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -238,6 +246,46 @@ def test_cached_table(self): self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"]) + def test_binary_type(self): + """Test that binary type in data sources respects binaryAsBytes config""" + schema = StructType( + [ + StructField("id", StringType()), + StructField("bin", 
BinaryType()), + StructField("arr_bin", ArrayType(BinaryType())), + StructField("map_bin", MapType(StringType(), BinaryType())), + ] + ) + # Create DataFrame with binary data (can use either bytes or bytearray) + data = [Row(id="1", bin=b"hello", arr_bin=[b"a"], map_bin={"key": b"value"})] + df = self.spark.createDataFrame(data, schema) + + tmpPath = tempfile.mkdtemp() + try: + # Write to parquet + df.write.mode("overwrite").parquet(tmpPath) + + for conf_value in ["true", "false"]: + expected_type = bytes if conf_value == "true" else bytearray + expected_bin = b"hello" if conf_value == "true" else bytearray(b"hello") + expected_arr = b"a" if conf_value == "true" else bytearray(b"a") + expected_map = b"value" if conf_value == "true" else bytearray(b"value") + + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + result = self.spark.read.parquet(tmpPath).collect() + row = result[0] + # Check binary field + self.assertIsInstance(row.bin, expected_type) + self.assertEqual(row.bin, expected_bin) + # Check array of binary + self.assertIsInstance(row.arr_bin[0], expected_type) + self.assertEqual(row.arr_bin[0], expected_arr) + # Check map value + self.assertIsInstance(row.map_bin["key"], expected_type) + self.assertEqual(row.map_bin["key"], expected_map) + finally: + shutil.rmtree(tmpPath) + # "[SPARK-51182]: DataFrameWriter should throw dataPathNotSpecifiedError when path is not # specified" def test_save(self): diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 9d130a7f525ac..37f651a96a846 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -32,6 +32,7 @@ from pyspark.sql.types import ( StringType, IntegerType, + BinaryType, BooleanType, DoubleType, LongType, @@ -1446,6 +1447,89 @@ def my_udf(): self.spark.range(1).select(my_udf().alias("result")).show() + def test_udf_binary_type(self): + def get_binary_type(x): + return type(x).__name__ + + binary_udf = udf(get_binary_type, returnType="string") + + df = self.spark.createDataFrame( + [Row(b=b"hello world")], schema=StructType([StructField("b", BinaryType())]) + ) + + for conf_value in ["true", "false"]: + expected_type = "bytes" if conf_value == "true" else "bytearray" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + result = df.select(binary_udf(col("b")).alias("type_name")) + expected = self.spark.createDataFrame([Row(type_name=expected_type)]) + assertDataFrameEqual(result, expected) + + def test_udf_binary_type_in_nested_structures(self): + """Test that binary type in arrays, maps, and structs respects binaryAsBytes config""" + + # Test binary in array + def check_array_binary_type(arr): + return type(arr[0]).__name__ + + array_udf = udf(check_array_binary_type, returnType="string") + df_array = self.spark.createDataFrame( + [Row(arr=[b"hello world"])], + schema=StructType([StructField("arr", ArrayType(BinaryType()))]), + ) + + for conf_value in ["true", "false"]: + expected_type = "bytes" if conf_value == "true" else "bytearray" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + result = df_array.select(array_udf(col("arr")).alias("type_name")) + expected = self.spark.createDataFrame([Row(type_name=expected_type)]) + assertDataFrameEqual(result, expected) + + # Test binary in map value + def check_map_binary_type(m): + return type(list(m.values())[0]).__name__ + + map_udf = udf(check_map_binary_type, returnType="string") + df_map = self.spark.createDataFrame( 
+ [Row(m={"key": b"value"})], + schema=StructType([StructField("m", MapType(StringType(), BinaryType()))]), + ) + + for conf_value in ["true", "false"]: + expected_type = "bytes" if conf_value == "true" else "bytearray" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + result = df_map.select(map_udf(col("m")).alias("type_name")) + expected = self.spark.createDataFrame([Row(type_name=expected_type)]) + assertDataFrameEqual(result, expected) + + # Test binary in struct + def check_struct_binary_type(s): + return type(s.binary_field).__name__ + + struct_udf = udf(check_struct_binary_type, returnType="string") + df_struct = self.spark.createDataFrame( + [Row(s=Row(binary_field=b"test", other_field="value"))], + schema=StructType( + [ + StructField( + "s", + StructType( + [ + StructField("binary_field", BinaryType()), + StructField("other_field", StringType()), + ] + ), + ) + ] + ), + ) + + for conf_value in ["true", "false"]: + expected_type = "bytes" if conf_value == "true" else "bytearray" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + result = df_struct.select(struct_udf(col("s")).alias("type_name")) + expected = self.spark.createDataFrame([Row(type_name=expected_type)]) + assertDataFrameEqual(result, expected) + class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index b006ac6c14d4a..5cf6a0eba59c3 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -3044,6 +3044,19 @@ def eval(self, v1, v2, v3, v4): for idx, field in enumerate(result_df.schema.fields): self.assertEqual(field.dataType, expected_output_types[idx]) + def test_udtf_binary_type(self): + @udtf(returnType="type_name: string") + class BinaryTypeUDTF: + def eval(self, b): + # Check the type of the binary input and return type name as string + yield (type(b).__name__,) + + for conf_value in ["true", "false"]: + expected_type = "bytes" if conf_value == "true" else "bytearray" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + result = BinaryTypeUDTF(lit(b"test")).collect() + self.assertEqual(result[0]["type_name"], expected_type) + class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase): @classmethod @@ -3063,6 +3076,19 @@ def tearDownClass(cls): not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message ) class LegacyUDTFArrowTestsMixin(BaseUDTFTestsMixin): + def test_udtf_binary_type(self): + @udtf(returnType="type_name: string") + class BinaryTypeUDTF: + def eval(self, b): + # Check the type of the binary input and return type name as string + yield (type(b).__name__,) + + # For Arrow Python UDTF with legacy conversion BinaryType is always mapped to bytes + for conf_value in ["true", "false"]: + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}): + result = BinaryTypeUDTF(lit(b"test")).collect() + self.assertEqual(result[0]["type_name"], "bytes") + def test_eval_type(self): def upper(x: str): return upper(x) @@ -3389,6 +3415,11 @@ def tearDownClass(cls): class UDTFArrowTestsMixin(LegacyUDTFArrowTestsMixin): + def test_udtf_binary_type(self): + # For Arrow Python UDTF with non-legacy conversionBinaryType is mapped to + # bytes or bytearray consistently with non-Arrow Python UDTF behavior. 
+ BaseUDTFTestsMixin.test_udtf_binary_type(self) + def test_numeric_output_type_casting(self): class TestUDTF: def eval(self): diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt index d21e7f2eb24a1..2b87384586237 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt @@ -17,8 +17,8 @@ |decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['object', 'object'] |[None, Decimal('9.99')] | |string_values |string |['abc', '', 'hello'] |['object', 'object', 'object'] |['abc', '', 'hello'] | |string_null |string |[None, 'test'] |['object', 'object'] |[None, 'test'] | -|binary_values |binary |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] |['object', 'object', 'object'] |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] | -|binary_null |binary |[None, bytearray(b'test')] |['object', 'object'] |[None, bytearray(b'test')] | +|binary_values |binary |[b'abc', b'', b'ABC'] |['object', 'object', 'object'] |[b'abc', b'', b'ABC'] | +|binary_null |binary |[None, b'test'] |['object', 'object'] |[None, b'test'] | |boolean_values |boolean |[True, False] |['bool', 'bool'] |[True, False] | |boolean_null |boolean |[None, True] |['object', 'object'] |[None, True] | |date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['object', 'object'] |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_return_type_coercion.txt b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_return_type_coercion.txt index 7719d1805d9e9..052e930c2e1f2 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_return_type_coercion.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_return_type_coercion.txt @@ -12,7 +12,7 @@ |float |[None, None] |[1.0, 0.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |X |X |X |X |[12.0, 34.0] |X |[1.0, 2.0] |X |X |X |X |X |X | |double |[None, None] |[1.0, 0.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |X |X |X |X |[12.0, 34.0] |X |[1.0, 2.0] |X |X |X |X |X |X | |array |[None, None] |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |[[1, 2], [3, 4]] |[[1, 2, 3], [1, 2, 3]] |X |X |X |X |X |X |X | -|binary |[None, None] |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b'a'), bytearray(b' |[bytearray(b'12'), bytearray(b |X |X |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b'A'), bytearray(b' |X |X | +|binary |[None, None] |[b'\x01', b''] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'', b''] |[b'', b''] |[b'', b''] |[b'', b''] |[b'', b''] |[b'', b''] |[b'a', b'b'] 
|[b'12', b'34'] |X |X |[b'', b''] |[b'', b''] |[b'', b''] |[b'A', b'B'] |X |X | |decimal(10,0) |[None, None] |X |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |X |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |X |X |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |X |X |X |X |X |X |[Decimal('1'), Decimal('2')] |X |X |X |X |X |X | |map |[None, None] |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |[{'a': 1}, {'b': 2}] | |struct<_1:int> |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |[Row(_1=1), Row(_1=2)] |X | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt index 2572d48dbec7c..a3727dfd5d6b7 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt @@ -17,8 +17,8 @@ |decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['NoneType', 'Decimal'] |['None', '9.99'] | |string_values |string |['abc', '', 'hello'] |['str', 'str', 'str'] |['abc', '', 'hello'] | |string_null |string |[None, 'test'] |['NoneType', 'str'] |['None', 'test'] | -|binary_values |binary |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] |['bytearray', 'bytearray', 'bytearray'] |["bytearray(b'abc')", "bytearray(b'')", "bytearray(b'ABC')"] | -|binary_null |binary |[None, bytearray(b'test')] |['NoneType', 'bytearray'] |['None', "bytearray(b'test')"] | +|binary_values |binary |[b'abc', b'', b'ABC'] |['bytes', 'bytes', 'bytes'] |["b'abc'", "b''", "b'ABC'"] | +|binary_null |binary |[None, b'test'] |['NoneType', 'bytes'] |['None', "b'test'"] | |boolean_values |boolean |[True, False] |['bool', 'bool'] |['True', 'False'] | |boolean_null |boolean |[None, True] |['NoneType', 'bool'] |['None', 'True'] | |date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['date', 'date'] |['2020-02-02', '1970-01-01'] | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt index 2572d48dbec7c..a3727dfd5d6b7 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt @@ -17,8 +17,8 @@ |decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['NoneType', 'Decimal'] |['None', '9.99'] | |string_values |string |['abc', '', 'hello'] |['str', 'str', 'str'] |['abc', '', 'hello'] | |string_null |string |[None, 'test'] |['NoneType', 'str'] |['None', 'test'] | -|binary_values |binary |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] |['bytearray', 'bytearray', 'bytearray'] |["bytearray(b'abc')", "bytearray(b'')", "bytearray(b'ABC')"] | -|binary_null |binary |[None, bytearray(b'test')] |['NoneType', 'bytearray'] |['None', "bytearray(b'test')"] | +|binary_values |binary |[b'abc', b'', b'ABC'] |['bytes', 'bytes', 'bytes'] |["b'abc'", "b''", "b'ABC'"] | +|binary_null |binary |[None, b'test'] |['NoneType', 'bytes'] |['None', "b'test'"] | |boolean_values |boolean |[True, False] |['bool', 'bool'] |['True', 'False'] | |boolean_null |boolean |[None, True] |['NoneType', 'bool'] |['None', 'True'] | |date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['date', 'date'] |['2020-02-02', 
'1970-01-01'] | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt index 92f8af100e743..576af5f12102d 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt @@ -17,8 +17,8 @@ |decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['NoneType', 'Decimal'] |['None', '9.99'] | |string_values |string |['abc', '', 'hello'] |['str', 'str', 'str'] |['abc', '', 'hello'] | |string_null |string |[None, 'test'] |['NoneType', 'str'] |['None', 'test'] | -|binary_values |binary |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] |['bytes', 'bytes', 'bytes'] |["b'abc'", "b''", "b'ABC'"] | -|binary_null |binary |[None, bytearray(b'test')] |['NoneType', 'bytes'] |['None', "b'test'"] | +|binary_values |binary |[b'abc', b'', b'ABC'] |['bytes', 'bytes', 'bytes'] |["b'abc'", "b''", "b'ABC'"] | +|binary_null |binary |[None, b'test'] |['NoneType', 'bytes'] |['None', "b'test'"] | |boolean_values |boolean |[True, False] |['bool', 'bool'] |['True', 'False'] | |boolean_null |boolean |[None, True] |['NoneType', 'bool'] |['None', 'True'] | |date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['date', 'date'] |['2020-02-02', '1970-01-01'] | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_disabled.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_disabled.txt index 7f87b89a5fcf6..46dd597d926f4 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_disabled.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_disabled.txt @@ -12,7 +12,7 @@ |float |None |None |None |None |None |None |1.0 |None |None |None |None |None |None |X |X | |double |None |None |None |None |None |None |1.0 |None |None |None |None |None |None |X |X | |array |None |None |None |None |None |None |None |[1] |[1] |[1] |[65, 66, 67] |None |None |X |X | -|binary |None |None |None |bytearray(b'a') |None |None |None |None |None |None |bytearray(b'ABC') |None |None |X |X | +|binary |None |None |None |b'a' |None |None |None |None |None |None |b'ABC' |None |None |X |X | |decimal(10,0) |None |None |None |None |None |None |None |None |None |None |None |Decimal('1') |None |X |X | |map |None |None |None |None |None |None |None |None |None |None |None |None |{'a': 1} |X |X | |struct<_1:int> |None |X |X |X |X |X |X |X |Row(_1=1) |Row(_1=1) |X |X |Row(_1=None) |Row(_1=1) |Row(_1=1) | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_enabled.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_enabled.txt index c117113369e56..f08f03d6d7d71 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_enabled.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_enabled.txt @@ -12,7 +12,7 @@ |float |None |1.0 |1.0 |X |X |X |1.0 |X |X |X |X |1.0 |X |X |X | |double |None |1.0 |1.0 |X |X |X |1.0 |X |X |X |X |1.0 |X |X |X | |array |None |X |X |X |X |X |X |[1] |[1] |[1] |[65, 66, 67] |X |X |[1] |[1] | -|binary |None |X |X |X |X |X |X |X |X |X |bytearray(b'ABC') |X |X |X |X | +|binary |None |X |X |X |X |X |X |X |X |X |b'ABC' |X |X |X |X | |decimal(10,0) |None |X |X |X |X |X |X |X |X |X |X |Decimal('1') 
|X |X |X | |map |None |X |X |X |X |X |X |X |X |X |X |X |{'a': 1} |X |X | |struct<_1:int> |None |X |X |X |X |X |X |X |X |Row(_1=1) |X |X |Row(_1=None) |Row(_1=1) |Row(_1=1) | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_legacy_pandas.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_legacy_pandas.txt index a1809dfa9aab6..8b68abb412e58 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_legacy_pandas.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_legacy_pandas.txt @@ -12,7 +12,7 @@ |float |None |1.0 |1.0 |X |X |X |1.0 |X |X |X |X |1.0 |X |X |X | |double |None |1.0 |1.0 |X |X |X |1.0 |X |X |X |X |1.0 |X |X |X | |array |None |X |X |X |X |X |X |[1] |[1] |[1] |[65, 66, 67] |X |X |[1] |[1] | -|binary |None |bytearray(b'\x00') |bytearray(b'\x00') |X |X |X |X |bytearray(b'\x01\x00\x00\x00') |bytearray(b'\x01') |bytearray(b'\x01') |bytearray(b'ABC') |X |X |bytearray(b'\x01') |bytearray(b'\x01') | +|binary |None |b'\x00' |b'\x00' |X |X |X |X |b'\x01\x00\x00\x00' |b'\x01' |b'\x01' |b'ABC' |X |X |b'\x01' |b'\x01' | |decimal(10,0) |None |X |X |X |X |X |Decimal('1') |X |X |X |X |Decimal('1') |X |X |X | |map |None |X |X |X |X |X |X |X |X |X |X |X |{'a': 1} |X |X | |struct<_1:int> |None |X |X |X |X |X |X |Row(_1=1) |Row(_1=1) |Row(_1=1) |Row(_1=65) |X |Row(_1=None) |Row(_1=1) |Row(_1=1) | diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py b/python/pyspark/sql/worker/data_source_pushdown_filters.py index ac6f84e617150..b523cab7c49e7 100644 --- a/python/pyspark/sql/worker/data_source_pushdown_filters.py +++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py @@ -27,7 +27,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkAssertionError, PySparkValueError from pyspark.errors.exceptions.base import PySparkNotImplementedError -from pyspark.serializers import SpecialLengths, UTF8Deserializer, read_int, write_int +from pyspark.serializers import SpecialLengths, UTF8Deserializer, read_int, read_bool, write_int from pyspark.sql.datasource import ( DataSource, DataSourceReader, @@ -232,6 +232,7 @@ def main(infile: IO, outfile: IO) -> None: "The maximum arrow batch size should be greater than 0, but got " f"'{max_arrow_batch_size}'" ) + binary_as_bytes = read_bool(infile) # Return the read function and partitions. Doing this in the same worker as filter pushdown # helps reduce the number of Python worker calls. @@ -241,6 +242,7 @@ def main(infile: IO, outfile: IO) -> None: data_source=data_source, schema=schema, max_arrow_batch_size=max_arrow_batch_size, + binary_as_bytes=binary_as_bytes, ) # Return the supported filter indices. diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index 7c14ebfc53e4d..08d18acda78b9 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -175,11 +175,14 @@ def write_read_func_and_partitions( data_source: DataSource, schema: StructType, max_arrow_batch_size: int, + binary_as_bytes: bool, ) -> None: is_streaming = isinstance(reader, DataSourceStreamReader) # Create input converter. - converter = ArrowTableToRowsConversion._create_converter(BinaryType()) + converter = ArrowTableToRowsConversion._create_converter( + BinaryType(), none_on_identity=False, binary_as_bytes=binary_as_bytes + ) # Create output converter. 
return_type = schema @@ -352,6 +355,7 @@ def main(infile: IO, outfile: IO) -> None: enable_pushdown = read_bool(infile) is_streaming = read_bool(infile) + binary_as_bytes = read_bool(infile) # Instantiate data source reader. if is_streaming: @@ -390,6 +394,7 @@ def main(infile: IO, outfile: IO) -> None: data_source=data_source, schema=schema, max_arrow_batch_size=max_arrow_batch_size, + binary_as_bytes=binary_as_bytes, ) except BaseException as e: handle_worker_exception(e, outfile) diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py index 3e772031225d5..d752a176bcbad 100644 --- a/python/pyspark/sql/worker/write_into_data_source.py +++ b/python/pyspark/sql/worker/write_into_data_source.py @@ -171,6 +171,7 @@ def main(infile: IO, outfile: IO) -> None: overwrite = read_bool(infile) is_streaming = read_bool(infile) + binary_as_bytes = read_bool(infile) # Instantiate a data source. data_source = data_source_cls(options=options) # type: ignore @@ -205,7 +206,10 @@ def main(infile: IO, outfile: IO) -> None: import pyarrow as pa converters = [ - ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields + ArrowTableToRowsConversion._create_converter( + f.dataType, none_on_identity=False, binary_as_bytes=binary_as_bytes + ) + for f in schema.fields ] fields = schema.fieldNames() diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 1f5c5a086abfb..b409bd12ae379 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -257,3 +257,35 @@ def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20): def test_assert_remote_mode(self): # no need to test this in mixed mode pass + + def connect_conf(self, conf_dict): + """Context manager to set configuration on Spark Connect session""" + + @contextlib.contextmanager + def _connect_conf(): + old_values = {} + for key, value in conf_dict.items(): + old_values[key] = self.connect.conf.get(key, None) + self.connect.conf.set(key, value) + try: + yield + finally: + for key, old_value in old_values.items(): + if old_value is None: + self.connect.conf.unset(key) + else: + self.connect.conf.set(key, old_value) + + return _connect_conf() + + def both_conf(self, conf_dict): + """Context manager to set configuration on both classic and Connect sessions""" + + @contextlib.contextmanager + def _both_conf(): + with contextlib.ExitStack() as stack: + stack.enter_context(self.sql_conf(conf_dict)) + stack.enter_context(self.connect_conf(conf_dict)) + yield + + return _both_conf() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 25eaf26243917..de00696b41149 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1288,6 +1288,7 @@ def use_legacy_pandas_udf_conversion(runner_conf): def read_udtf(pickleSer, infile, eval_type): prefers_large_var_types = False legacy_pandas_conversion = False + binary_as_bytes = True if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF: runner_conf = {} @@ -1304,6 +1305,9 @@ def read_udtf(pickleSer, infile, eval_type): ).lower() == "true" ) + binary_as_bytes = ( + runner_conf.get("spark.sql.execution.pyspark.binaryAsBytes", "true").lower() == "true" + ) input_types = [ field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile)) ] @@ -2248,7 +2252,9 @@ def evaluate(*args: list, num_rows=1): def mapper(_, it): try: converters = [ - ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True) 
+            ArrowTableToRowsConversion._create_converter( + dt, none_on_identity=True, binary_as_bytes=binary_as_bytes + ) for dt in input_types ] for a in it: @@ -2545,6 +2551,9 @@ def read_udfs(pickleSer, infile, eval_type): ).lower() == "true" ) + binary_as_bytes = ( + runner_conf.get("spark.sql.execution.pyspark.binaryAsBytes", "true").lower() == "true" + ) _assign_cols_by_name = assign_cols_by_name(runner_conf) if eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF: @@ -2636,7 +2645,7 @@ def read_udfs(pickleSer, infile, eval_type): f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile)) ] ser = ArrowBatchUDFSerializer( - timezone, safecheck, input_types, int_to_decimal_coercion_enabled + timezone, safecheck, input_types, int_to_decimal_coercion_enabled, binary_as_bytes ) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 17b8dd493cf80..80dd336f69bd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3724,6 +3724,19 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PYSPARK_BINARY_AS_BYTES = + buildConf("spark.sql.execution.pyspark.binaryAsBytes") + .doc("When true, BinaryType is mapped consistently to bytes in PySpark. " + + "When false, matches the PySpark behavior before 4.1.0. Before 4.1.0, BinaryType is " + + "mapped to bytearray for input of regular UDF and UDTF without Arrow optimization, " + + "regular UDF and UDTF with Arrow optimization and without legacy pandas conversion, " + + "DataFrame APIs, and data sources; BinaryType is mapped to bytes for " + + "input of regular UDF and UDTF with Arrow optimization and legacy pandas conversion.") + .version("4.1.0") + .booleanConf + .createWithDefault(true) + + val ARROW_LOCAL_RELATION_THRESHOLD = buildConf("spark.sql.execution.arrow.localRelationThreshold") .doc( @@ -7101,6 +7114,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED) + def pysparkBinaryAsBytes: Boolean = getConf(PYSPARK_BINARY_AS_BYTES) + def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED) def pythonUDFProfiler: Option[String] = getConf(PYTHON_UDF_PROFILER) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 85f59c282ff55..471e376e1c22d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -117,7 +117,7 @@ private[sql] object PythonSQLUtils extends Logging { def toPyRow(row: Row): Array[Byte] = { assert(row.isInstanceOf[GenericRowWithSchema]) withInternalRowPickler(_.dumps(EvaluatePython.toJava( - CatalystTypeConverters.convertToCatalyst(row), row.schema))) + CatalystTypeConverters.convertToCatalyst(row), row.schema, SQLConf.get.pysparkBinaryAsBytes))) } def toJVMRow( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index b8ffa09dfa05c..65881cf5d03b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -2072,14 +2072,17 @@ class Dataset[T] private[sql]( */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val structType = schema // capture it for closure - val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) + val binaryAsBytes = sparkSession.sessionState.conf.pysparkBinaryAsBytes // capture config value + val rdd = queryExecution.toRdd.map(row => + EvaluatePython.toJava(row, structType, binaryAsBytes)) EvaluatePython.javaToPython(rdd) } private[sql] def collectToPython(): Array[Any] = { EvaluatePython.registerPicklers() withAction("collectToPython", queryExecution) { plan => - val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) + val binaryAsBytes = sparkSession.sessionState.conf.pysparkBinaryAsBytes + val toJava: (Any) => Any = EvaluatePython.toJava(_, schema, binaryAsBytes) val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( plan.executeCollect().iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-DataFrame") @@ -2089,7 +2092,8 @@ class Dataset[T] private[sql]( private[sql] def tailToPython(n: Int): Array[Any] = { EvaluatePython.registerPicklers() withAction("tailToPython", queryExecution) { plan => - val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) + val binaryAsBytes = sparkSession.sessionState.conf.pysparkBinaryAsBytes + val toJava: (Any) => Any = EvaluatePython.toJava(_, schema, binaryAsBytes) val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( plan.executeTail(n).iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-DataFrame") @@ -2102,7 +2106,9 @@ class Dataset[T] private[sql]( EvaluatePython.registerPicklers() val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1) val rows = getRows(numRows, truncate).map(_.toArray).toArray - val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) + val binaryAsBytes = sparkSession.sessionState.conf.pysparkBinaryAsBytes + val toJava: (Any) => Any = + EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)), binaryAsBytes) val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( rows.iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-GetRows") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala index 14aeba92dafe1..26bd5368e6f9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala @@ -458,6 +458,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner( // Send configurations dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch) + dataOut.writeBoolean(SQLConf.get.pysparkBinaryAsBytes) } override protected def receiveFromPython(dataIn: DataInputStream): PythonFilterPushdownResult = { @@ -550,6 +551,7 @@ private class UserDefinedPythonDataSourceReadRunner( dataOut.writeBoolean(SQLConf.get.pythonFilterPushDown) dataOut.writeBoolean(isStreaming) + dataOut.writeBoolean(SQLConf.get.pysparkBinaryAsBytes) } override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReadInfo = { @@ -600,6 +602,7 @@ private class UserDefinedPythonDataSourceWriteRunner( dataOut.writeBoolean(overwrite) dataOut.writeBoolean(isStreaming) + 
dataOut.writeBoolean(SQLConf.get.pysparkBinaryAsBytes) } override protected def receiveFromPython( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 609fa218f1288..77aec2a35f21d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -155,9 +155,12 @@ object ArrowPythonRunner { val intToDecimalCoercion = Seq( SQLConf.PYTHON_UDF_PANDAS_INT_TO_DECIMAL_COERCION_ENABLED.key -> conf.getConf(SQLConf.PYTHON_UDF_PANDAS_INT_TO_DECIMAL_COERCION_ENABLED, false).toString) + val binaryAsBytes = Seq( + SQLConf.PYSPARK_BINARY_AS_BYTES.key -> + conf.pysparkBinaryAsBytes.toString) Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ arrowAyncParallelism ++ useLargeVarTypes ++ - intToDecimalCoercion ++ + intToDecimalCoercion ++ binaryAsBytes ++ legacyPandasConversion ++ legacyPandasConversionUDF: _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 28318a319b088..866719122ec4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -41,6 +41,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] override protected def evaluatorFactory: EvalPythonEvaluatorFactory = { val batchSize = conf.getConf(SQLConf.PYTHON_UDF_MAX_RECORDS_PER_BATCH) + val binaryAsBytes = conf.pysparkBinaryAsBytes new BatchEvalPythonEvaluatorFactory( child.output, udfs, @@ -48,7 +49,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] batchSize, pythonMetrics, jobArtifactUUID, - conf.pythonUDFProfiler) + conf.pythonUDFProfiler, + binaryAsBytes) } override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonExec = @@ -62,7 +64,8 @@ class BatchEvalPythonEvaluatorFactory( batchSize: Int, pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], - profiler: Option[String]) + profiler: Option[String], + binaryAsBytes: Boolean) extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { override def evaluate( @@ -74,7 +77,7 @@ class BatchEvalPythonEvaluatorFactory( EvaluatePython.registerPicklers() // register pickler for Row // Input iterator to Python. - val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema, batchSize) + val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema, batchSize, binaryAsBytes) // Output iterator for results from Python. val outputIterator = @@ -112,7 +115,8 @@ object BatchEvalPythonExec { def getInputIterator( iter: Iterator[InternalRow], schema: StructType, - batchSize: Int): Iterator[Array[Byte]] = { + batchSize: Int, + binaryAsBytes: Boolean): Iterator[Array[Byte]] = { val dataTypes = schema.map(_.dataType) val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) @@ -133,14 +137,14 @@ object BatchEvalPythonExec { // For each row, add it to the queue. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index a1358c9cd7746..7b46ab4bd34aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -65,7 +65,8 @@ case class BatchEvalPythonUDTFExec(
     // Input iterator to Python.
     // For Python UDTF, we don't have a separate configuration for the batch size yet.
-    val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema, 100)
+    val inputIterator = BatchEvalPythonExec.getInputIterator(
+      iter, schema, 100, conf.pysparkBinaryAsBytes)

     // Output iterator for results from Python.
     val outputIterator =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 5d117a67e6bee..212cc5db124ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,6 +35,12 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

 object EvaluatePython {

+  /**
+   * Wrapper class for byte arrays that should be pickled as Python bytes instead of bytearray.
+   * This is a marker class that tells the pickler to use bytes() constructor.
+   */
+  private[python] class BytesWrapper(val data: Array[Byte])
+
   def needConversionInPython(dt: DataType): Boolean = dt match {
     case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType |
       _: TimeType => true
@@ -49,39 +55,52 @@ object EvaluatePython {

   /**
    * Helper for converting from Catalyst type to java type suitable for Pickle.
    */
-  def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-    case (null, _) => null
-
-    case (row: InternalRow, struct: StructType) =>
-      val values = new Array[Any](row.numFields)
-      var i = 0
-      while (i < row.numFields) {
-        values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType)
-        i += 1
-      }
-      new GenericRowWithSchema(values, struct)
+  def toJava(
+      obj: Any,
+      dataType: DataType,
+      binaryAsBytes: Boolean): Any = {
+    (obj, dataType) match {
+      case (null, _) => null
+
+      case (row: InternalRow, struct: StructType) =>
+        val values = new Array[Any](row.numFields)
+        var i = 0
+        while (i < row.numFields) {
+          val field = struct.fields(i)
+          values(i) = toJava(row.get(i, field.dataType), field.dataType, binaryAsBytes)
+          i += 1
+        }
+        new GenericRowWithSchema(values, struct)

-    case (a: ArrayData, array: ArrayType) =>
-      val values = new java.util.ArrayList[Any](a.numElements())
-      a.foreach(array.elementType, (_, e) => {
-        values.add(toJava(e, array.elementType))
-      })
-      values
+      case (a: ArrayData, array: ArrayType) =>
+        val values = new java.util.ArrayList[Any](a.numElements())
+        a.foreach(array.elementType, (_, e) => {
+          values.add(toJava(e, array.elementType, binaryAsBytes))
+        })
+        values

-    case (map: MapData, mt: MapType) =>
-      val jmap = new java.util.HashMap[Any, Any](map.numElements())
-      map.foreach(mt.keyType, mt.valueType, (k, v) => {
-        jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType))
-      })
-      jmap
+      case (map: MapData, mt: MapType) =>
+        val jmap = new java.util.HashMap[Any, Any](map.numElements())
+        map.foreach(mt.keyType, mt.valueType, (k, v) => {
+          jmap.put(toJava(k, mt.keyType, binaryAsBytes), toJava(v, mt.valueType, binaryAsBytes))
+        })
+        jmap
+
+      case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType, binaryAsBytes)

-    case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
+      case (d: Decimal, _) => d.toJavaBigDecimal

-    case (d: Decimal, _) => d.toJavaBigDecimal
+      case (s: UTF8String, _: StringType) => s.toString

-    case (s: UTF8String, _: StringType) => s.toString
+      case (bytes: Array[Byte], BinaryType) =>
+        if (binaryAsBytes) {
+          new BytesWrapper(bytes)
+        } else {
+          bytes
+        }

-    case (other, _) => other
+      case (other, _) => other
+    }
   }

   /**
@@ -248,6 +267,37 @@ object EvaluatePython {
     }
   }

+  /**
+   * Pickler for BytesWrapper that pickles byte arrays as Python bytes using bytes() builtin.
+   * Structure: bytes(bytearray_data) where bytearray_data is pickled by razorvine's
+   * default pickler.
+   */
+  private class BytesWrapperPickler extends IObjectPickler {
+
+    private val cls = classOf[BytesWrapper]
+
+    def register(): Unit = {
+      Pickler.registerCustomPickler(cls, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      // Pickle structure: bytes(bytearray_value)
+      // GLOBAL 'builtins' 'bytes'
+      out.write(Opcodes.GLOBAL)
+      out.write("builtins\nbytes\n".getBytes(StandardCharsets.UTF_8))
+
+      // Pickle the wrapped byte array data using razorvine's built-in pickler
+      val wrapper = obj.asInstanceOf[BytesWrapper]
+      pickler.save(wrapper.data)
+
+      // TUPLE1 creates a 1-tuple: (bytearray_value,)
+      out.write(Opcodes.TUPLE1)
+
+      // REDUCE calls bytes(bytearray_value)
+      out.write(Opcodes.REDUCE)
+    }
+  }
+
   /**
    * Pickler for external row.
    */
@@ -299,6 +349,7 @@ object EvaluatePython {
         SerDeUtil.initialize()
         new StructTypePickler().register()
         new RowPickler().register()
+        new BytesWrapperPickler().register()
         registered = true
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 037b784c9dd18..8987d6e36ff96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -231,7 +231,8 @@ class UserDefinedPythonTableFunctionAnalyzeRunner(
       }
       if (value.foldable) {
         dataOut.writeBoolean(true)
-        val obj = pickler.dumps(EvaluatePython.toJava(value.eval(), value.dataType))
+        val obj = pickler.dumps(EvaluatePython.toJava(
+          value.eval(), value.dataType, SQLConf.get.pysparkBinaryAsBytes))
         PythonWorkerUtils.writeBytes(obj, dataOut)
       } else {
         dataOut.writeBoolean(false)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
index b1b79946c2fba..37716d2d8413b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
@@ -89,7 +89,9 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)

   private lazy val inputByteIterator = {
     EvaluatePython.registerPicklers()
-    val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) }
+    val objIterator = inputRowIterator.map { row =>
+      EvaluatePython.toJava(row, schema, SQLConf.get.pysparkBinaryAsBytes)
+    }
     new SerDeUtil.AutoBatchedPickler(objIterator)
   }
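For reference, the opcode sequence written by BytesWrapperPickler above is the pickle program GLOBAL 'builtins' 'bytes', the pickled payload, TUPLE1, REDUCE, which a Python unpickler evaluates as bytes(payload). The sketch below shows the same technique as a standalone razorvine custom pickler; DemoBytes and DemoBytesPickler are illustrative names, not part of the patch.

import java.io.OutputStream
import java.nio.charset.StandardCharsets

import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}

// Illustration only: a wrapper plus a custom pickler that makes the Python side
// reconstruct the value via bytes(...) instead of the default bytearray.
class DemoBytes(val data: Array[Byte])

object DemoBytesPickler extends IObjectPickler {
  def register(): Unit = Pickler.registerCustomPickler(classOf[DemoBytes], this)

  def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
    out.write(Opcodes.GLOBAL)                                    // push the builtins.bytes callable
    out.write("builtins\nbytes\n".getBytes(StandardCharsets.UTF_8))
    pickler.save(obj.asInstanceOf[DemoBytes].data)               // push the raw byte payload
    out.write(Opcodes.TUPLE1)                                    // wrap it as a 1-tuple of arguments
    out.write(Opcodes.REDUCE)                                    // call bytes(payload) when unpickling
  }
}

// Usage: DemoBytesPickler.register(); new Pickler().dumps(new DemoBytes(Array[Byte](1, 2, 3)))
// then yields a stream that a Python unpickler turns into b'\x01\x02\x03'.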