Skip to content

Commit acf186d

Browse files
committed
#731 Add an option to copy data type when copying metadata.
1 parent f9be3c7 commit acf186d

File tree

2 files changed

+127
-5
lines changed

2 files changed

+127
-5
lines changed

spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,21 @@ object SparkUtils extends Logging {
4848
allExecutors.filter(!_.equals(driverHost)).toList.distinct
4949
}
5050

51+
/**
  * Returns true if a Spark data type is a primitive data type.
  *
  * Composite types (arrays, structs, maps) are considered non-primitive;
  * every other Spark type (numeric, string, binary, boolean, date/time, etc.)
  * is considered primitive.
  *
  * @param dataType A Spark data type.
  * @return true if the data type is primitive.
  */
def isPrimitive(dataType: DataType): Boolean = {
  dataType match {
    // All composite (nested) types in one alternative pattern; anything else is primitive.
    case _: ArrayType | _: StructType | _: MapType => false
    case _                                         => true
  }
}
65+
5166
/**
5267
* Given an instance of DataFrame returns a dataframe with flattened schema.
5368
* All nested structures are flattened and arrays are projected as columns.
@@ -248,12 +263,14 @@ object SparkUtils extends Logging {
248263
* @param schemaTo Schema to copy metadata to.
249264
* @param overwrite If true, the metadata of schemaTo is not retained
250265
* @param sourcePreferred If true, schemaFrom metadata is used on conflicts, schemaTo otherwise.
266+
* @param copyDataType If true, data type is copied as well. This is limited to primitive data types.
251267
* @return Same schema as schemaTo with metadata from schemaFrom.
252268
*/
253269
def copyMetadata(schemaFrom: StructType,
254270
schemaTo: StructType,
255271
overwrite: Boolean = false,
256-
sourcePreferred: Boolean = false): StructType = {
272+
sourcePreferred: Boolean = false,
273+
copyDataType: Boolean = false): StructType = {
257274
def joinMetadata(from: Metadata, to: Metadata): Metadata = {
258275
val newMetadataMerged = new MetadataBuilder
259276

@@ -273,12 +290,16 @@ object SparkUtils extends Logging {
273290
ar.elementType match {
274291
case st: StructType if fieldFrom.dataType.isInstanceOf[ArrayType] && fieldFrom.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType] =>
275292
val innerStructFrom = fieldFrom.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
276-
val newDataType = StructType(copyMetadata(innerStructFrom, st).fields)
293+
val newDataType = StructType(copyMetadata(innerStructFrom, st, overwrite, sourcePreferred, copyDataType).fields)
277294
ArrayType(newDataType, ar.containsNull)
278295
case at: ArrayType =>
279296
processArray(at, fieldFrom, fieldTo)
280297
case p =>
281-
ArrayType(p, ar.containsNull)
298+
if (copyDataType && fieldFrom.dataType.isInstanceOf[ArrayType] && isPrimitive(fieldFrom.dataType.asInstanceOf[ArrayType].elementType)) {
299+
ArrayType(fieldFrom.dataType.asInstanceOf[ArrayType].elementType, ar.containsNull)
300+
} else {
301+
ArrayType(p, ar.containsNull)
302+
}
282303
}
283304
}
284305

@@ -295,13 +316,17 @@ object SparkUtils extends Logging {
295316

296317
fieldTo.dataType match {
297318
case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
298-
val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st).fields)
319+
val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st, overwrite, sourcePreferred, copyDataType).fields)
299320
fieldTo.copy(dataType = newDataType, metadata = newMetadata)
300321
case at: ArrayType =>
301322
val newType = processArray(at, fieldFrom, fieldTo)
302323
fieldTo.copy(dataType = newType, metadata = newMetadata)
303324
case _ =>
304-
fieldTo.copy(metadata = newMetadata)
325+
if (copyDataType && isPrimitive(fieldFrom.dataType)) {
326+
fieldTo.copy(dataType = fieldFrom.dataType, metadata = newMetadata)
327+
} else {
328+
fieldTo.copy(metadata = newMetadata)
329+
}
305330
}
306331
case None =>
307332
fieldTo

spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,24 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
3939
"""[{"id":4,"legs":[]}]""" ::
4040
"""[{"id":5,"legs":null}]""" :: Nil
4141

42+
test("IsPrimitive should work as expected") {
  // Every atomic Spark type must be reported as primitive.
  val primitiveTypes = Seq(
    BooleanType, ByteType, ShortType, IntegerType, LongType,
    FloatType, DoubleType, DecimalType(10, 2), StringType,
    BinaryType, DateType, TimestampType
  )
  primitiveTypes.foreach { dataType =>
    assert(SparkUtils.isPrimitive(dataType), s"$dataType should be primitive")
  }

  // Composite (nested) types must be reported as non-primitive.
  val compositeTypes = Seq(
    ArrayType(StringType),
    StructType(Seq(StructField("a", StringType))),
    MapType(StringType, StringType)
  )
  compositeTypes.foreach { dataType =>
    assert(!SparkUtils.isPrimitive(dataType), s"$dataType should not be primitive")
  }
}
59+
4260
test("Test schema flattening of multiple nested structure") {
4361
val expectedOrigSchema =
4462
"""root
@@ -626,6 +644,85 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
626644
assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
627645
}
628646

647+
test("copyMetadata should copy primitive data types when it is enabled") {
  // Local helpers so each field's metadata reads as a one-liner.
  def comment(text: String): Metadata = new MetadataBuilder().putString("comment", text).build()
  def maxLength(len: Long): Metadata = new MetadataBuilder().putLong("maxLength", len).build()

  val schemaFrom = StructType(Seq(
    StructField("int_field1", IntegerType, nullable = true, metadata = comment("Test1")),
    StructField("string_field", StringType, nullable = true, metadata = maxLength(120)),
    StructField("int_field2", StructType(Seq(
      StructField("int_field20", IntegerType, nullable = true, metadata = comment("Test20"))
    )), nullable = true),
    StructField("struct_field2", StructType(Seq(
      StructField("int_field3", IntegerType, nullable = true, metadata = comment("Test3"))
    )), nullable = true),
    StructField("array_string", ArrayType(StringType), nullable = true, metadata = maxLength(60)),
    StructField("array_struct", ArrayType(StructType(Seq(
      StructField("int_field4", IntegerType, nullable = true, metadata = comment("Test4"))
    ))), nullable = true)
  ))

  val schemaTo = StructType(Seq(
    StructField("int_field1", BooleanType, nullable = true),
    StructField("string_field", IntegerType, nullable = true),
    StructField("int_field2", IntegerType, nullable = true),
    StructField("struct_field2", StructType(Seq(
      StructField("int_field3", BooleanType, nullable = true)
    )), nullable = true),
    StructField("array_string", ArrayType(IntegerType), nullable = true),
    StructField("array_struct", ArrayType(StructType(Seq(
      StructField("int_field4", StringType, nullable = true)
    ))), nullable = true)
  ))

  val fields = SparkUtils.copyMetadata(schemaFrom, schemaTo, copyDataType = true).fields

  // Primitive leaf types come from schemaFrom; composite fields keep schemaTo's
  // shape (a struct in schemaFrom never replaces a primitive in schemaTo).
  // Expected result:
  // root
  //  |-- int_field1: integer          (copied from schemaFrom)
  //  |-- string_field: string         (copied from schemaFrom)
  //  |-- int_field2: integer          (schemaFrom field is a struct -> type not copied)
  //  |-- struct_field2: struct
  //  |    |-- int_field3: integer     (copied from schemaFrom)
  //  |-- array_string: array<string>  (element type copied from schemaFrom)
  //  |-- array_struct: array<struct>
  //  |    |-- int_field4: integer     (copied from schemaFrom)
  assert(fields.head.dataType == IntegerType)
  assert(fields(1).dataType == StringType)
  assert(fields(2).dataType == IntegerType)
  assert(fields(3).dataType.isInstanceOf[StructType])
  assert(fields(4).dataType.isInstanceOf[ArrayType])
  assert(fields(5).dataType.isInstanceOf[ArrayType])

  assert(fields(3).dataType.asInstanceOf[StructType].fields.head.dataType == IntegerType)
  assert(fields(4).dataType.asInstanceOf[ArrayType].elementType == StringType)
  assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType])
  assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.dataType == IntegerType)

  // Metadata travels along with the copied data types.
  assert(fields.head.metadata.getString("comment") == "Test1")
  assert(fields(1).metadata.getLong("maxLength") == 120)
  assert(fields(3).dataType.asInstanceOf[StructType].fields.head.metadata.getString("comment") == "Test3")
  assert(fields(4).metadata.getLong("maxLength") == 60)
  assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.getString("comment") == "Test4")
}
725+
629726
test("copyMetadata should retain metadata on conflicts by default") {
630727
val df1 = List(1, 2, 3).toDF("col1")
631728
val df2 = List(1, 2, 3).toDF("col1")

0 commit comments

Comments
 (0)