Commit 0835540

RD-14980: Support for INSERT, UPDATE, DELETE
1 parent fd270d1 commit 0835540

3 files changed: +146, -19 lines

src/main/scala/com/rawlabs/das/databricks/DASDatabricks.scala

Lines changed: 10 additions & 2 deletions
@@ -14,7 +14,7 @@ package com.rawlabs.das.databricks
 
 import com.databricks.sdk.WorkspaceClient
 import com.databricks.sdk.core.DatabricksConfig
-import com.databricks.sdk.service.catalog.ListTablesRequest
+import com.databricks.sdk.service.catalog.{GetTableRequest, ListTablesRequest}
 import com.databricks.sdk.service.sql.ListWarehousesRequest
 import com.rawlabs.das.sdk.{DASFunction, DASSdk, DASTable}
 import com.rawlabs.protocol.das.{FunctionDefinition, TableDefinition}
@@ -41,7 +41,15 @@ class DASDatabricks(options: Map[String, String]) extends DASSdk {
     val databricksTables = databricksClient.tables().list(req)
     val tables = mutable.Map.empty[String, DASDatabricksTable]
     databricksTables.forEach { databricksTable =>
-      tables.put(databricksTable.getName, new DASDatabricksTable(databricksClient, warehouse, databricksTable))
+      // `databricksTable` is a `TableInfo`; its `getTableConstraints` would tell us whether the table
+      // has a primary key column, which we can use for UPDATE calls, but the list call doesn't populate it.
+      // We have to issue an individual `GetTableRequest` call (the single-table call, which returns the
+      // same object but with the constraints provided).
+      val tableDetails = {
+        val tableReq = new GetTableRequest().setFullName(catalog + '.' + schema + '.' + databricksTable.getName)
+        databricksClient.tables().get(tableReq)
+      }
+      tables.put(databricksTable.getName, new DASDatabricksTable(databricksClient, warehouse, tableDetails))
     }
     tables.toMap
   }
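
Note: the sketch below is not part of this commit. It is a minimal standalone illustration of the list-then-get pattern introduced above, using only SDK calls that appear in the diff (GetTableRequest, tables().get, getTableConstraints, getPrimaryKeyConstraint, getChildColumns). The catalog, schema and table names are hypothetical.

import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.service.catalog.GetTableRequest

object ConstraintLookupSketch extends App {
  // Hypothetical workspace configuration and table name, for illustration only.
  val client = new WorkspaceClient()
  val fullName = "my_catalog.my_schema.orders"

  // Unlike tables().list(...), the single-table call populates the constraints.
  val info = client.tables().get(new GetTableRequest().setFullName(fullName))

  val constraints = info.getTableConstraints
  if (constraints != null) {
    constraints.forEach { c =>
      val pk = c.getPrimaryKeyConstraint
      if (pk != null && pk.getChildColumns.size == 1) {
        println(s"single-column primary key: ${pk.getChildColumns.iterator().next()}")
      }
    }
  }
}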

src/main/scala/com/rawlabs/das/databricks/DASDatabricksTable.scala

Lines changed: 136 additions & 14 deletions
@@ -15,18 +15,21 @@ package com.rawlabs.das.databricks
 import com.databricks.sdk.WorkspaceClient
 import com.databricks.sdk.service.catalog.{ColumnInfo, ColumnTypeName, TableInfo}
 import com.databricks.sdk.service.sql._
-import com.rawlabs.das.sdk.{DASExecuteResult, DASTable}
+import com.rawlabs.das.sdk.{DASExecuteResult, DASSdkException, DASTable}
 import com.rawlabs.protocol.das._
 import com.rawlabs.protocol.raw.{Type, Value}
 import com.typesafe.scalalogging.StrictLogging
 
 import scala.annotation.tailrec
 import scala.collection.JavaConverters.collectionAsScalaIterableConverter
+import scala.collection.mutable
 
 class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databricksTable: TableInfo)
     extends DASTable
     with StrictLogging {
 
+  private val tableFullName = databricksTable.getSchemaName + '.' + databricksTable.getName
+
   override def getRelSize(quals: Seq[Qual], columns: Seq[String]): (Int, Int) = REL_SIZE
 
   override def execute(
@@ -36,8 +39,7 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
       maybeLimit: Option[Long]
   ): DASExecuteResult = {
     val databricksColumns = if (columns.isEmpty) Seq("NULL") else columns.map(databricksColumnName)
-    var query =
-      s"SELECT ${databricksColumns.mkString(",")} FROM " + databricksTable.getSchemaName + '.' + databricksTable.getName
+    var query = s"SELECT ${databricksColumns.mkString(",")} FROM " + tableFullName
     val stmt = new ExecuteStatementRequest()
     val parameters = new java.util.LinkedList[StatementParameterListItem]
     if (quals.nonEmpty) {
@@ -93,9 +95,11 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
 
     stmt.setStatement(query).setWarehouseId(warehouseID).setDisposition(Disposition.INLINE).setFormat(Format.JSON_ARRAY)
     val executeAPI = client.statementExecution()
-    val response1 = executeAPI.executeStatement(stmt)
-    val response = getResult(response1)
-    new DASDatabricksExecuteResult(executeAPI, response)
+    val response = executeAPI.executeStatement(stmt)
+    getResult(response) match {
+      case Left(error) => throw new DASSdkException(error)
+      case Right(result) => new DASDatabricksExecuteResult(executeAPI, result)
+    }
   }
 
   private def databricksColumnName(name: String): String = {
@@ -119,21 +123,19 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
   override def canSort(sortKeys: Seq[SortKey]): Seq[SortKey] = sortKeys
 
   @tailrec
-  private def getResult(response: StatementResponse): StatementResponse = {
+  private def getResult(response: StatementResponse): Either[String, StatementResponse] = {
     val state = response.getStatus.getState
     logger.info(s"Query ${response.getStatementId} state: $state")
     state match {
       case StatementState.PENDING | StatementState.RUNNING =>
+        logger.info(s"Query is still running, polling again in $POLLING_TIME ms")
         Thread.sleep(POLLING_TIME)
         val response2 = client.statementExecution().getStatement(response.getStatementId)
         getResult(response2)
-      case StatementState.SUCCEEDED => response
-      case StatementState.FAILED =>
-        throw new RuntimeException(s"Query failed: ${response.getStatus.getError.getMessage}")
-      case StatementState.CLOSED =>
-        throw new RuntimeException(s"Query closed: ${response.getStatus.getError.getMessage}")
-      case StatementState.CANCELED =>
-        throw new RuntimeException(s"Query canceled: ${response.getStatus.getError.getMessage}")
+      case StatementState.SUCCEEDED => Right(response)
+      case StatementState.FAILED => Left(s"Query failed: ${response.getStatus.getError.getMessage}")
+      case StatementState.CLOSED => Left(s"Query closed: ${response.getStatus.getError.getMessage}")
+      case StatementState.CANCELED => Left(s"Query canceled: ${response.getStatus.getError.getMessage}")
     }
   }
 
@@ -161,6 +163,26 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
     definition.build()
   }
 
+  // Primary key column name, if one was found in the table constraints metadata.
+  private var primaryKeyColumn: Option[String] = None
+
+  // Try to find a primary key constraint over one column.
+  if (databricksTable.getTableConstraints == null) {
+    logger.warn(s"No constraints found for table $tableFullName")
+  } else {
+    databricksTable.getTableConstraints.forEach { constraint =>
+      val primaryKeyConstraint = constraint.getPrimaryKeyConstraint
+      if (primaryKeyConstraint != null) {
+        if (primaryKeyConstraint.getChildColumns.size != 1) {
+          logger.warn("Ignoring composite primary key")
+        } else {
+          primaryKeyColumn = Some(primaryKeyConstraint.getChildColumns.iterator().next())
+          logger.info(s"Found primary key ($primaryKeyColumn)")
+        }
+      }
+    }
+  }
+
   private def columnType(info: ColumnInfo): Option[Type] = {
     val builder = Type.newBuilder()
     val columnType = info.getTypeName
@@ -230,6 +252,11 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
     }
   }
 
+  override def uniqueColumn: String = {
+    // Fall back to the first column if no primary key was found.
+    primaryKeyColumn.getOrElse(databricksTable.getColumns.asScala.head.getName)
+  }
+
   private def rawValueToParameter(v: Value): StatementParameterListItem = {
     logger.debug(s"Converting value to parameter: $v")
     val parameter = new StatementParameterListItem()
@@ -286,4 +313,99 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
     }
   }
 
+  override def insert(row: Row): Row = {
+    bulkInsert(Seq(row)).head
+  }
+
+  // INSERTs can be batched, but only by inlining values into the query string.
+  // We don't want to accidentally send gigantic query strings, so we keep each
+  // query at around this size.
+  private val MAX_INSERT_CODE_SIZE = 2048
+
+  override def bulkInsert(rows: Seq[Row]): Seq[Row] = {
+    // There's no bulk insert call in Databricks, so we inline values. We build
+    // batches of query strings that are at most MAX_INSERT_CODE_SIZE characters
+    // long and loop until all rows are consumed.
+    val columnNames = databricksTable.getColumns.asScala.map(_.getName)
+    val values = rows.map { row =>
+      val data = row.getDataMap
+      columnNames
+        .map { name =>
+          val value = data.get(name)
+          if (value == null) {
+            "DEFAULT"
+          } else {
+            rawValueToDatabricksQueryString(value)
+          }
+        }
+        .mkString("(", ",", ")")
+    }
+    val stmt = new ExecuteStatementRequest()
+      .setWarehouseId(warehouseID)
+      .setDisposition(Disposition.INLINE)
+      .setFormat(Format.JSON_ARRAY)
+
+    val items = values.iterator
+    while (items.nonEmpty) {
+      val item = items.next()
+      val code = StringBuilder.newBuilder
+      code.append(s"INSERT INTO ${databricksTable.getName} VALUES $item")
+      while (code.size < MAX_INSERT_CODE_SIZE && items.hasNext) {
+        code.append(s",${items.next()}")
+      }
+      stmt.setStatement(code.toString())
+      val executeAPI = client.statementExecution()
+      val response = executeAPI.executeStatement(stmt)
+      getResult(response).left.foreach(error => throw new RuntimeException(error))
+    }
+    rows
+  }
+
+  override def delete(rowId: Value): Unit = {
+    if (primaryKeyColumn.isEmpty) {
+      throw new IllegalArgumentException(s"Table $tableFullName has no primary key column")
+    }
+    val stmt = new ExecuteStatementRequest()
+      .setWarehouseId(warehouseID)
+      .setDisposition(Disposition.INLINE)
+      .setFormat(Format.JSON_ARRAY)
+    stmt.setStatement(
+      s"DELETE FROM ${databricksTable.getName} WHERE ${databricksColumnName(uniqueColumn)} = ${rawValueToDatabricksQueryString(rowId)}"
+    )
+    val executeAPI = client.statementExecution()
+    val response = executeAPI.executeStatement(stmt)
+    getResult(response).left.foreach(error => throw new RuntimeException(error))
+  }
+
+  // How many rows are accepted in a batch update. Technically we're unlimited,
+  // since updates are sent one by one.
+  private val MODIFY_BATCH_SIZE = 1000
+
+  override def modifyBatchSize: Int = {
+    MODIFY_BATCH_SIZE
+  }
+
+  override def update(rowId: Value, newValues: Row): Row = {
+    if (primaryKeyColumn.isEmpty) {
+      throw new IllegalArgumentException(s"Table $tableFullName has no primary key column")
+    }
+    val buffer = mutable.Buffer.empty[String]
+    newValues.getDataMap
+      .forEach {
+        case (name, value) =>
+          buffer.append(s"${databricksColumnName(name)} = ${rawValueToDatabricksQueryString(value)}")
+      }
+    val setValues = buffer.mkString(", ")
+    val stmt = new ExecuteStatementRequest()
+      .setWarehouseId(warehouseID)
+      .setDisposition(Disposition.INLINE)
+      .setFormat(Format.JSON_ARRAY)
+    stmt.setStatement(
+      s"UPDATE ${databricksTable.getName} SET $setValues WHERE ${databricksColumnName(uniqueColumn)} = ${rawValueToDatabricksQueryString(rowId)}"
+    )
+    val executeAPI = client.statementExecution()
+    val response = executeAPI.executeStatement(stmt)
+    getResult(response).left.foreach(error => throw new RuntimeException(error))
+    newValues
+  }
 }
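
Note: the snippet below is not part of this commit. It is a standalone sketch of the size-capped batching idea that bulkInsert uses: pre-rendered row tuples are appended to one INSERT statement until it reaches roughly the size cap, then a new statement is started. The table name, row values and size cap are hypothetical.

object InsertBatchingSketch extends App {
  // Pack rendered row tuples, e.g. "(1,'a')", into INSERT statements of bounded length.
  def packInsertStatements(table: String, renderedRows: Seq[String], maxSize: Int): Seq[String] = {
    val statements = Seq.newBuilder[String]
    val items = renderedRows.iterator
    while (items.hasNext) {
      val sql = new StringBuilder(s"INSERT INTO $table VALUES ${items.next()}")
      while (sql.length < maxSize && items.hasNext) {
        sql.append(s",${items.next()}")
      }
      statements += sql.toString()
    }
    statements.result()
  }

  // With a small cap, the three rows below split into two statements.
  packInsertStatements("orders", Seq("(1,'a')", "(2,'b')", "(3,'c')"), maxSize = 40).foreach(println)
}

By contrast, the UPDATE and DELETE paths above send one statement per row, of the form UPDATE <table> SET <col> = <value>, ... WHERE <pk> = <rowId> and DELETE FROM <table> WHERE <pk> = <rowId>, which is why the comment on MODIFY_BATCH_SIZE notes that the batch size is technically unlimited.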

src/main/scala/com/rawlabs/das/databricks/id.kt

Lines changed: 0 additions & 3 deletions
This file was deleted.
