@@ -39,6 +39,24 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
39
39
""" [{"id":4,"legs":[]}]""" ::
40
40
""" [{"id":5,"legs":null}]""" :: Nil
41
41
42
// Checks SparkUtils.isPrimitive against every scalar Spark SQL type used here (expected: true)
// and against the container types array, struct and map (expected: false).
test(" IsPrimitive should work as expected") {
  // Types that must be reported as primitive.
  val primitiveTypes = Seq[org.apache.spark.sql.types.DataType](
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType(10, 2),
    StringType,
    BinaryType,
    DateType,
    TimestampType
  )

  // Container types that must not be reported as primitive.
  val complexTypes = Seq[org.apache.spark.sql.types.DataType](
    ArrayType(StringType),
    StructType(Seq(StructField(" a", StringType))),
    MapType(StringType, StringType)
  )

  // The clue names the offending type, since all checks now share a single source line.
  primitiveTypes.foreach { dataType =>
    assert(SparkUtils.isPrimitive(dataType), s"$dataType should be primitive")
  }

  complexTypes.foreach { dataType =>
    assert(!SparkUtils.isPrimitive(dataType), s"$dataType should not be primitive")
  }
}
59
+
42
60
test(" Test schema flattening of multiple nested structure" ) {
43
61
val expectedOrigSchema =
44
62
""" root
@@ -626,6 +644,85 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
626
644
assert(newDf.schema.fields.head.metadata.getLong(" maxLength" ) == 120 )
627
645
}
628
646
647
// Verifies that SparkUtils.copyMetadata(..., copyDataType = true) carries over both the field
// metadata and the primitive data types of the source schema, including inside nested structs
// and arrays.
test(" copyMetadata should copy primitive data types when it is enabled") {
  // Builders for the two kinds of metadata attached to the source fields.
  def comment(text: String) = new MetadataBuilder().putString(" comment", text).build()
  def maxLength(length: Long) = new MetadataBuilder().putLong(" maxLength", length).build()

  // Source schema: provides the metadata and the data types expected in the result.
  val schemaFrom = StructType(
    Seq(
      StructField(" int_field1", IntegerType, nullable = true, metadata = comment(" Test1")),
      StructField(" string_field", StringType, nullable = true, metadata = maxLength(120)),
      // Despite its name, this field is a struct on the source side while the target declares
      // an integer, so the two schemas disagree on whether the field is primitive.
      StructField(" int_field2", StructType(
        Seq(
          StructField(" int_field20", IntegerType, nullable = true, metadata = comment(" Test20"))
        )
      ), nullable = true),
      StructField(" struct_field2", StructType(
        Seq(
          StructField(" int_field3", IntegerType, nullable = true, metadata = comment(" Test3"))
        )
      ), nullable = true),
      StructField(" array_string", ArrayType(StringType), nullable = true, metadata = maxLength(60)),
      StructField(" array_struct", ArrayType(StructType(
        Seq(
          StructField(" int_field4", IntegerType, nullable = true, metadata = comment(" Test4"))
        )
      )), nullable = true)
    )
  )

  // Target schema: same field names, no metadata, and different primitive types.
  val schemaTo = StructType(
    Seq(
      StructField(" int_field1", BooleanType, nullable = true),
      StructField(" string_field", IntegerType, nullable = true),
      StructField(" int_field2", IntegerType, nullable = true),
      StructField(" struct_field2", StructType(
        Seq(
          StructField(" int_field3", BooleanType, nullable = true)
        )
      ), nullable = true),
      StructField(" array_string", ArrayType(IntegerType), nullable = true),
      StructField(" array_struct", ArrayType(StructType(
        Seq(
          StructField(" int_field4", StringType, nullable = true)
        )
      )), nullable = true)
    )
  )

  val schemaWithMetadata = SparkUtils.copyMetadata(schemaFrom, schemaTo, copyDataType = true)
  val fields = schemaWithMetadata.fields

  // Ensure data types are copied.
  // Schema expected by the assertions below. Primitive types come from schemaFrom;
  // int_field2 is a struct in schemaFrom, so it keeps the integer type of schemaTo.
  // root
  //  |-- int_field1: integer
  //  |-- string_field: string
  //  |-- int_field2: integer
  //  |-- struct_field2: struct
  //  |    |-- int_field3: integer
  //  |-- array_string: array
  //  |    |-- element: string
  //  |-- array_struct: array
  //  |    |-- element: struct
  //  |    |    |-- int_field4: integer
  assert(fields.head.dataType == IntegerType)
  assert(fields(1).dataType == StringType)
  assert(fields(2).dataType == IntegerType)
  assert(fields(3).dataType.isInstanceOf[StructType])
  assert(fields(4).dataType.isInstanceOf[ArrayType])
  assert(fields(5).dataType.isInstanceOf[ArrayType])

  // The casts are safe here: the container types were asserted just above.
  val nestedStructField = fields(3).dataType.asInstanceOf[StructType].fields.head
  val stringArray = fields(4).dataType.asInstanceOf[ArrayType]
  val structArray = fields(5).dataType.asInstanceOf[ArrayType]

  assert(nestedStructField.dataType == IntegerType)
  assert(stringArray.elementType == StringType)
  assert(structArray.elementType.isInstanceOf[StructType])

  val arrayElementField = structArray.elementType.asInstanceOf[StructType].fields.head
  assert(arrayElementField.dataType == IntegerType)

  // Ensure metadata is copied
  assert(fields.head.metadata.getString(" comment") == " Test1")
  assert(fields(1).metadata.getLong(" maxLength") == 120)
  assert(nestedStructField.metadata.getString(" comment") == " Test3")
  assert(fields(4).metadata.getLong(" maxLength") == 60)
  assert(arrayElementField.metadata.getString(" comment") == " Test4")
}
725
+
629
726
test(" copyMetadata should retain metadata on conflicts by default" ) {
630
727
val df1 = List (1 , 2 , 3 ).toDF(" col1" )
631
728
val df2 = List (1 , 2 , 3 ).toDF(" col1" )
0 commit comments