Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.apache.flinkx

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}

package object api {

/** Basic type has an arity of 1. See [[BasicTypeInfo#getArity()]] */
private[api] val BasicTypeArity: Int = 1

/** Basic type has 1 field. See [[BasicTypeInfo#getTotalFields()]] */
private[api] val BasicTypeTotalFields: Int = 1

/** Documentation of [[TypeInformation#getTotalFields()]] states the total number of fields must be at least 1. */
private[api] val MinimumTotalFields: Int = 1

}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ class ArraySerializer[T](val child: TypeSerializer[T], clazz: Class[T]) extends
}
}

override def copy(source: DataInputView, target: DataOutputView): Unit = {
var remaining = source.readInt()
target.writeInt(remaining)
while (remaining > 0) {
child.copy(source, target)
remaining -= 1
}
}

override def snapshotConfiguration(): TypeSerializerSnapshot[Array[T]] =
new CollectionSerializerSnapshot[Array, T, ArraySerializer[T]](child, classOf[ArraySerializer[T]], clazz)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ class CoproductSerializer[T](subtypeClasses: Array[Class[_]], subtypeSerializers
subtype.asInstanceOf[TypeSerializer[T]].deserialize(source)
}

override def copy(source: DataInputView, target: DataOutputView): Unit = {
val index = source.readByte()
val subtype = subtypeSerializers(index.toInt)
target.writeByte(index)
subtype.asInstanceOf[TypeSerializer[T]].copy(source, target)
}

override def snapshotConfiguration(): TypeSerializerSnapshot[T] =
new CoproductSerializerSnapshot(subtypeClasses, subtypeSerializers)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ class ListCCSerializer[T](child: TypeSerializer[T], clazz: Class[T]) extends Mut
record.foreach(element => child.serialize(element, target))
}

override def copy(source: DataInputView, target: DataOutputView): Unit = {
var remaining = source.readInt()
target.writeInt(remaining)
while (remaining > 0) {
child.copy(source, target)
remaining -= 1
}
}

override def snapshotConfiguration(): TypeSerializerSnapshot[::[T]] =
new CollectionSerializerSnapshot[::, T, ListCCSerializer[T]](child, classOf[ListCCSerializer[T]], clazz)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,30 @@ class ListSerializer[T](child: TypeSerializer[T], clazz: Class[T]) extends Mutab
override def createInstance(): List[T] = List.empty[T]
override def getLength: Int = -1
override def deserialize(source: DataInputView): List[T] = {
val count = source.readInt()
val result = for {
_ <- 0 until count
} yield {
child.deserialize(source)
var remaining = source.readInt()
val builder = List.newBuilder[T]
builder.sizeHint(remaining)
while (remaining > 0) {
builder.addOne(child.deserialize(source))
remaining -= 1
}
result.toList
builder.result()
}

override def serialize(record: List[T], target: DataOutputView): Unit = {
target.writeInt(record.size)
record.foreach(element => child.serialize(element, target))
}

override def copy(source: DataInputView, target: DataOutputView): Unit = {
var remaining = source.readInt()
target.writeInt(remaining)
while (remaining > 0) {
child.copy(source, target)
remaining -= 1
}
}

override def snapshotConfiguration(): TypeSerializerSnapshot[List[T]] =
new CollectionSerializerSnapshot[List, T, ListSerializer[T]](child, classOf[ListSerializer[T]], clazz)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ class MapSerializer[K, V](ks: TypeSerializer[K], vs: TypeSerializer[V]) extends
override def createInstance(): Map[K, V] = Map.empty[K, V]
override def getLength: Int = -1
override def deserialize(source: DataInputView): Map[K, V] = {
val count = source.readInt()
val result = for {
_ <- 0 until count
} yield {
var remaining = source.readInt()
val builder = Map.newBuilder[K, V]
builder.sizeHint(remaining)
while (remaining > 0) {
val key = ks.deserialize(source)
val value = vs.deserialize(source)
key -> value
builder.addOne(key -> value)
remaining -= 1
}
result.toMap
builder.result()
}
override def serialize(record: Map[K, V], target: DataOutputView): Unit = {
target.writeInt(record.size)
Expand All @@ -48,6 +49,16 @@ class MapSerializer[K, V](ks: TypeSerializer[K], vs: TypeSerializer[V]) extends
})
}

override def copy(source: DataInputView, target: DataOutputView): Unit = {
var remaining = source.readInt()
target.writeInt(remaining)
while (remaining > 0) {
ks.copy(source, target)
vs.copy(source, target)
remaining -= 1
}
}

override def snapshotConfiguration(): TypeSerializerSnapshot[Map[K, V]] = new MapSerializerSnapshot(ks, vs)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ case class MappedSerializer[A, B](mapper: TypeMapper[A, B], ser: TypeSerializer[

override def deserialize(source: DataInputView): A = mapper.contramap(ser.deserialize(source))

override def copy(source: DataInputView, target: DataOutputView): Unit = ser.copy(source, target)

override def snapshotConfiguration(): TypeSerializerSnapshot[A] = new MappedSerializerSnapshot[A, B](mapper, ser)

override def createInstance(): A = mapper.contramap(ser.createInstance())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ package org.apache.flinkx.api.serializer
import org.apache.flink.api.common.typeutils.{TypeSerializer, TypeSerializerSnapshot}
import org.apache.flink.core.memory.{DataInputView, DataOutputView}

import scala.collection.immutable
import scala.collection.immutable.ArraySeq
import scala.reflect.ClassTag

class SeqSerializer[T](child: TypeSerializer[T], clazz: Class[T]) extends MutableSerializer[Seq[T]] {

private implicit val classTag: ClassTag[T] = ClassTag(clazz)

override val isImmutableType: Boolean = child.isImmutableType

override def copy(from: Seq[T]): Seq[T] = {
Expand All @@ -27,19 +33,29 @@ class SeqSerializer[T](child: TypeSerializer[T], clazz: Class[T]) extends Mutabl
override def createInstance(): Seq[T] = Seq.empty[T]
override def getLength: Int = -1
override def deserialize(source: DataInputView): Seq[T] = {
val count = source.readInt()
val result = for {
_ <- 0 until count
} yield {
child.deserialize(source)
val length = source.readInt()
val array = new Array[T](length)
var i = 0
while (i < length) {
array(i) = child.deserialize(source)
i += 1
}
result
ArraySeq.unsafeWrapArray(array)
}
override def serialize(record: Seq[T], target: DataOutputView): Unit = {
target.writeInt(record.size)
record.foreach(element => child.serialize(element, target))
}

override def copy(source: DataInputView, target: DataOutputView): Unit = {
var remaining = source.readInt()
target.writeInt(remaining)
while (remaining > 0) {
child.copy(source, target)
remaining -= 1
}
}

override def snapshotConfiguration(): TypeSerializerSnapshot[Seq[T]] =
new CollectionSerializerSnapshot[Seq, T, SeqSerializer[T]](child, classOf[SeqSerializer[T]], clazz)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,29 @@ class SetSerializer[T](child: TypeSerializer[T], clazz: Class[T]) extends Mutabl
override def createInstance(): Set[T] = Set.empty[T]
override def getLength: Int = -1
override def deserialize(source: DataInputView): Set[T] = {
val count = source.readInt()
val result = for {
_ <- 0 until count
} yield {
child.deserialize(source)
var remaining = source.readInt()
val builder = Set.newBuilder[T]
builder.sizeHint(remaining)
while (remaining > 0) {
builder.addOne(child.deserialize(source))
remaining -= 1
}
result.toSet
builder.result()
}
override def serialize(record: Set[T], target: DataOutputView): Unit = {
target.writeInt(record.size)
record.foreach(element => child.serialize(element, target))
}

override def copy(source: DataInputView, target: DataOutputView): Unit = {
var remaining = source.readInt()
target.writeInt(remaining)
while (remaining > 0) {
child.copy(source, target)
remaining -= 1
}
}

override def snapshotConfiguration(): TypeSerializerSnapshot[Set[T]] =
new CollectionSerializerSnapshot[Set, T, SetSerializer[T]](child, classOf[SetSerializer[T]], clazz)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,29 @@ class VectorSerializer[T](child: TypeSerializer[T], clazz: Class[T]) extends Mut
override def createInstance(): Vector[T] = Vector.empty[T]
override def getLength: Int = -1
override def deserialize(source: DataInputView): Vector[T] = {
val count = source.readInt()
val result = for {
_ <- 0 until count
} yield {
child.deserialize(source)
var remaining = source.readInt()
val builder = Vector.newBuilder[T]
builder.sizeHint(remaining)
while (remaining > 0) {
builder.addOne(child.deserialize(source))
remaining -= 1
}
result.toVector
builder.result()
}
override def serialize(record: Vector[T], target: DataOutputView): Unit = {
target.writeInt(record.size)
record.foreach(element => child.serialize(element, target))
}

override def copy(source: DataInputView, target: DataOutputView): Unit = {
var remaining = source.readInt()
target.writeInt(remaining)
while (remaining > 0) {
child.copy(source, target)
remaining -= 1
}
}

override def snapshotConfiguration(): TypeSerializerSnapshot[Vector[T]] =
new CollectionSerializerSnapshot[Vector, T, VectorSerializer[T]](child, classOf[VectorSerializer[T]], clazz)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.flink.api.common.typeutils.CompositeType.{
}
import org.apache.flink.api.common.typeutils._
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
import org.apache.flinkx.api.MinimumTotalFields

import java.util.regex.{Matcher, Pattern}
import scala.annotation.tailrec
Expand Down Expand Up @@ -63,6 +64,9 @@ class CaseClassTypeInfo[T <: Product](
Pattern.compile(REGEX_NESTED_FIELDS_WILDCARD)
private val PATTERN_INT_FIELD: Pattern = Pattern.compile(REGEX_INT_FIELD)

override def getTotalFields: Int =
if (super.getTotalFields == 0) MinimumTotalFields else super.getTotalFields

@PublicEvolving
def getFieldIndices(fields: Array[String]): Array[Int] = {
fields map { x => fieldNames.indexOf(x) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.serialization.SerializerConfig
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flinkx.api.MinimumTotalFields

import scala.reflect.{ClassTag, classTag}

Expand All @@ -17,7 +18,7 @@ case class CollectionTypeInformation[T: ClassTag](serializer: TypeSerializer[T])
override def isBasicType: Boolean = false
override def isTupleType: Boolean = false
override def isKeyType: Boolean = false
override def getTotalFields: Int = 1
override def getTotalFields: Int = MinimumTotalFields
override def getTypeClass: Class[T] = clazz
override def getArity: Int = 1
override def getArity: Int = 0
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.serialization.SerializerConfig
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flinkx.api.MinimumTotalFields

case class CoproductTypeInformation[T](c: Class[T], ser: TypeSerializer[T]) extends TypeInformation[T] {

Expand All @@ -14,7 +15,7 @@ case class CoproductTypeInformation[T](c: Class[T], ser: TypeSerializer[T]) exte
override def isBasicType: Boolean = false
override def isTupleType: Boolean = false
override def isKeyType: Boolean = false
override def getTotalFields: Int = 1
override def getTotalFields: Int = MinimumTotalFields
override def getTypeClass: Class[T] = c
override def getArity: Int = 1
override def getArity: Int = 0
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import org.apache.flink.api.common.serialization.SerializerConfig
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flinkx.api.serializer.UnitSerializer
import org.apache.flinkx.api.{BasicTypeArity, BasicTypeTotalFields}

class UnitTypeInformation extends TypeInformation[Unit] {

Expand All @@ -13,11 +14,11 @@ class UnitTypeInformation extends TypeInformation[Unit] {
def createSerializer(config: ExecutionConfig): TypeSerializer[Unit] = new UnitSerializer()

override def isKeyType: Boolean = false
override def getTotalFields: Int = 0
override def getTotalFields: Int = BasicTypeTotalFields
override def isTupleType: Boolean = false
override def canEqual(obj: Any): Boolean = obj.isInstanceOf[UnitTypeInformation]
override def getTypeClass: Class[Unit] = classOf[Unit]
override def getArity: Int = 0
override def getArity: Int = BasicTypeArity
override def isBasicType: Boolean = true

override def toString: String = "{}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.apache.flinkx.api

import cats.data.NonEmptyList
import org.apache.flink.api.common.serialization.SerializerConfigImpl
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flinkx.api.AnyTest.FAny.FValueAny.FTerm
import org.apache.flinkx.api.AnyTest.FAny.FValueAny.FTerm.StringTerm
Expand All @@ -11,29 +10,21 @@ import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class AnyTest extends AnyFlatSpec with Matchers with TestUtils {
val ec = new SerializerConfigImpl()

def createSerializer[T: TypeInformation] =
implicitly[TypeInformation[T]].createSerializer(ec)

it should "serialize concrete class" in {
val ser = createSerializer[StringTerm]
roundtrip(ser, StringTerm("fo"))
testTypeInfoAndSerializer(StringTerm("fo"))
}

it should "serialize ADT" in {
val ser = createSerializer[FAny]
roundtrip(ser, StringTerm("fo"))
testTypeInfoAndSerializer(StringTerm("fo"), nullable = false)
}

it should "serialize NEL" in {
val ser = createSerializer[NonEmptyList[FTerm]]
roundtrip(ser, NonEmptyList.one(StringTerm("fo")))
testTypeInfoAndSerializer(NonEmptyList.one(StringTerm("fo")))
}

it should "serialize nested nel" in {
val ser = createSerializer[TermFilter]
roundtrip(ser, TermFilter("a", NonEmptyList.one(StringTerm("fo"))))
testTypeInfoAndSerializer(TermFilter("a", NonEmptyList.one(StringTerm("fo"))))
}
}

Expand Down
Loading
Loading