Skip to content

Commit a68e18a

Browse files
authored
Merge pull request #1050 from Kotlin/plugin-group-by
Enhanced GroupBy support in compiler plugin
2 parents 05d911b + 16f5d51 commit a68e18a

File tree

11 files changed

+118
-9
lines changed

11 files changed

+118
-9
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -248,6 +248,8 @@ public fun <T> DataFrame<T>.add(body: AddDsl<T>.() -> Unit): DataFrame<T> {
248248
return dataFrameOf(this@add.columns() + dsl.columns).cast()
249249
}
250250

251+
@Refine
252+
@Interpretable("GroupByAdd")
251253
public inline fun <reified R, T, G> GroupBy<T, G>.add(
252254
name: String,
253255
infer: Infer = Infer.Nulls,

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -239,8 +239,8 @@ class FunctionCallTransformer(
239239
val groupMarker = rootMarkers[1]
240240

241241
val (keySchema, groupSchema) = if (groupBy != null) {
242-
val keySchema = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop)
243-
val groupSchema = PluginDataFrameSchema(groupBy.df.columns())
242+
val keySchema = groupBy.keys
243+
val groupSchema = groupBy.groups
244244
keySchema to groupSchema
245245
} else {
246246
PluginDataFrameSchema.EMPTY to PluginDataFrameSchema.EMPTY

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/ExpectedArgumentDelegates.kt

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.plugin.impl
22

33
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter.*
4+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
45
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
56
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.DataFrameCallableId
67
import kotlin.properties.PropertyDelegateProvider
@@ -35,3 +36,7 @@ internal fun <T> AbstractInterpreter<T>.ignore(
3536
): ExpectedArgumentProvider<Nothing?> =
3637
arg(name, lens = Interpreter.Id, defaultValue = Present(null))
3738

39+
internal fun <T> AbstractInterpreter<T>.groupBy(
40+
name: ArgumentName? = null
41+
): ExpectedArgumentProvider<GroupBy> = arg(name, lens = Interpreter.GroupBy)
42+

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/Interpreter.kt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@ interface Interpreter<T> {
2626

2727
data object Schema : Lens
2828

29+
data object GroupBy : Lens
30+
2931
data object Id : Lens
3032

3133
// required to compute whether resulting schema should be inheritor of previous class or a new class

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,10 @@ data class PluginDataFrameSchema(
2727
}
2828
}
2929

30+
fun PluginDataFrameSchema.add(name: String, type: ConeKotlinType, context: KotlinTypeFacade): PluginDataFrameSchema {
31+
return PluginDataFrameSchema(columns() + context.simpleColumnOf(name, type))
32+
}
33+
3034
private fun List<SimpleCol>.asString(indent: String = ""): String {
3135
return joinToString("\n") {
3236
val col = when (it) {

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -19,19 +19,22 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
1919
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
2020
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
2121
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
22+
import org.jetbrains.kotlinx.dataframe.plugin.impl.add
2223
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
2324
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
25+
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
2426
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
27+
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
2528

26-
class GroupBy(val df: PluginDataFrameSchema, val keys: List<ColumnWithPathApproximation>, val moveToTop: Boolean)
29+
class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema)
2730

2831
class DataFrameGroupBy : AbstractInterpreter<GroupBy>() {
2932
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
3033
val Arguments.moveToTop: Boolean by arg(defaultValue = Present(true))
3134
val Arguments.cols: ColumnsResolver by arg()
3235

3336
override fun Arguments.interpret(): GroupBy {
34-
return GroupBy(receiver, cols.resolve(receiver), moveToTop)
37+
return GroupBy(keys = createPluginDataFrameSchema(cols.resolve(receiver), moveToTop), groups = receiver)
3538
}
3639
}
3740

@@ -52,7 +55,7 @@ class GroupByInto : AbstractInterpreter<Unit>() {
5255
}
5356

5457
class Aggregate : AbstractSchemaModificationInterpreter() {
55-
val Arguments.receiver: GroupBy by arg()
58+
val Arguments.receiver: GroupBy by groupBy()
5659
val Arguments.body: FirAnonymousFunctionExpression by arg(lens = Interpreter.Id)
5760
override fun Arguments.interpret(): PluginDataFrameSchema {
5861
return aggregate(
@@ -87,7 +90,7 @@ fun KotlinTypeFacade.aggregate(
8790
)
8891
}
8992

90-
val cols = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop).columns() + dsl.columns.map {
93+
val cols = groupBy.keys.columns() + dsl.columns.map {
9194
simpleColumnOf(it.name, it.type)
9295
}
9396
PluginDataFrameSchema(cols)
@@ -144,13 +147,23 @@ fun KotlinTypeFacade.createPluginDataFrameSchema(keys: List<ColumnWithPathApprox
144147
}
145148

146149
class GroupByToDataFrame : AbstractSchemaModificationInterpreter() {
147-
val Arguments.receiver: GroupBy by arg()
150+
val Arguments.receiver: GroupBy by groupBy()
148151
val Arguments.groupedColumnName: String? by arg(defaultValue = Present(null))
149152

150153
override fun Arguments.interpret(): PluginDataFrameSchema {
151-
val grouped = listOf(SimpleFrameColumn(groupedColumnName ?: "group", receiver.df.columns()))
154+
val grouped = listOf(SimpleFrameColumn(groupedColumnName ?: "group", receiver.groups.columns()))
152155
return PluginDataFrameSchema(
153-
createPluginDataFrameSchema(receiver.keys, receiver.moveToTop).columns() + grouped
156+
receiver.keys.columns() + grouped
154157
)
155158
}
156159
}
160+
161+
class GroupByAdd : AbstractInterpreter<GroupBy>() {
162+
val Arguments.receiver: GroupBy by groupBy()
163+
val Arguments.name: String by arg()
164+
val Arguments.type: TypeApproximation by type(name("expression"))
165+
166+
override fun Arguments.interpret(): GroupBy {
167+
return GroupBy(receiver.keys, receiver.groups.add(name, type.type, context = this))
168+
}
169+
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ import org.jetbrains.kotlin.fir.references.resolved
3636
import org.jetbrains.kotlin.fir.references.symbol
3737
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
3838
import org.jetbrains.kotlin.fir.resolve.fqName
39+
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
3940
import org.jetbrains.kotlin.fir.scopes.collectAllProperties
4041
import org.jetbrains.kotlin.fir.scopes.getProperties
4142
import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope
@@ -78,6 +79,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
7879
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
7980
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
8081
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnsResolver
82+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
8183
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.SingleColumnApproximation
8284
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
8385

@@ -277,6 +279,17 @@ fun <T> KotlinTypeFacade.interpret(
277279
}
278280
}
279281

282+
is Interpreter.GroupBy -> {
283+
assert(expectedReturnType.toString() == GroupBy::class.qualifiedName!!) {
284+
"'$name' should be ${GroupBy::class.qualifiedName!!}, but plugin expect $expectedReturnType"
285+
}
286+
287+
val resolvedType = it.expression.resolvedType.fullyExpandedType(session)
288+
val keys = pluginDataFrameSchema(resolvedType.typeArguments[0])
289+
val groups = pluginDataFrameSchema(resolvedType.typeArguments[1])
290+
Interpreter.Success(GroupBy(keys, groups))
291+
}
292+
280293
is Interpreter.Id -> {
281294
Interpreter.Success(it.expression)
282295
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0
8888
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0
8989
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault
9090
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
91+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd
9192
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
9293
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0
9394
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MoveAfter0
@@ -275,6 +276,7 @@ internal inline fun <reified T> String.load(): T {
275276
"MoveToLeft1" -> MoveToLeft1()
276277
"MoveToRight0" -> MoveToRight0()
277278
"MoveAfter0" -> MoveAfter0()
279+
"GroupByAdd" -> GroupByAdd()
278280
else -> error("$this")
279281
} as T
280282
}
Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,42 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.api.groupBy
5+
import org.jetbrains.kotlinx.dataframe.io.*
6+
7+
enum class State {
8+
Idle,
9+
Productive,
10+
Maintenance,
11+
}
12+
13+
class Event(val toolId: String, val state: State, val timestamp: Long)
14+
15+
fun box(): String {
16+
val tool1 = "tool_1"
17+
val tool2 = "tool_2"
18+
val tool3 = "tool_3"
19+
20+
val events = listOf(
21+
Event(tool1, State.Idle, 0),
22+
Event(tool1, State.Productive, 5),
23+
Event(tool2, State.Idle, 0),
24+
Event(tool2, State.Maintenance, 10),
25+
Event(tool2, State.Idle, 20),
26+
Event(tool3, State.Idle, 0),
27+
Event(tool3, State.Productive, 25),
28+
).toDataFrame()
29+
30+
val lastTimestamp = events.maxOf { timestamp }
31+
val groupBy = events
32+
.groupBy { toolId }
33+
.sortBy { timestamp }
34+
.add("stateDuration") {
35+
(next()?.timestamp ?: lastTimestamp) - timestamp
36+
}.toDataFrame()
37+
38+
groupBy.group[0].stateDuration
39+
40+
groupBy.compareSchemas(strict = true)
41+
return "OK"
42+
}
Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
fun box(): String {
7+
val df = dataFrameOf("a", "b", "c")(1, 2, 3)
8+
9+
val groupBy = df.groupBy { a }
10+
11+
val df1 = groupBy.updateGroups { it.remove { a } }.toDataFrame()
12+
df1.compileTimeSchema().print()
13+
return "OK"
14+
}

0 commit comments

Comments
 (0)