@@ -15,18 +15,21 @@ package com.rawlabs.das.databricks
 import com.databricks.sdk.WorkspaceClient
 import com.databricks.sdk.service.catalog.{ColumnInfo, ColumnTypeName, TableInfo}
 import com.databricks.sdk.service.sql._
-import com.rawlabs.das.sdk.{DASExecuteResult, DASTable}
+import com.rawlabs.das.sdk.{DASExecuteResult, DASSdkException, DASTable}
 import com.rawlabs.protocol.das._
 import com.rawlabs.protocol.raw.{Type, Value}
 import com.typesafe.scalalogging.StrictLogging
 
 import scala.annotation.tailrec
 import scala.collection.JavaConverters.collectionAsScalaIterableConverter
+import scala.collection.mutable
 
 class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databricksTable: TableInfo)
     extends DASTable
     with StrictLogging {
 
+  private val tableFullName = databricksTable.getSchemaName + '.' + databricksTable.getName
+
   override def getRelSize(quals: Seq[Qual], columns: Seq[String]): (Int, Int) = REL_SIZE
 
   override def execute(
@@ -36,8 +39,7 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
       maybeLimit: Option[Long]
   ): DASExecuteResult = {
     val databricksColumns = if (columns.isEmpty) Seq("NULL") else columns.map(databricksColumnName)
-    var query =
-      s"SELECT ${databricksColumns.mkString(",")} FROM " + databricksTable.getSchemaName + '.' + databricksTable.getName
+    var query = s"SELECT ${databricksColumns.mkString(",")} FROM " + tableFullName
     val stmt = new ExecuteStatementRequest()
     val parameters = new java.util.LinkedList[StatementParameterListItem]
     if (quals.nonEmpty) {
@@ -93,9 +95,11 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
 
     stmt.setStatement(query).setWarehouseId(warehouseID).setDisposition(Disposition.INLINE).setFormat(Format.JSON_ARRAY)
     val executeAPI = client.statementExecution()
-    val response1 = executeAPI.executeStatement(stmt)
-    val response = getResult(response1)
-    new DASDatabricksExecuteResult(executeAPI, response)
+    val response = executeAPI.executeStatement(stmt)
+    getResult(response) match {
+      case Left(error) => throw new DASSdkException(error)
+      case Right(result) => new DASDatabricksExecuteResult(executeAPI, result)
+    }
   }
 
   private def databricksColumnName(name: String): String = {
@@ -119,21 +123,19 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
   override def canSort(sortKeys: Seq[SortKey]): Seq[SortKey] = sortKeys
 
   @tailrec
-  private def getResult(response: StatementResponse): StatementResponse = {
+  private def getResult(response: StatementResponse): Either[String, StatementResponse] = {
     val state = response.getStatus.getState
     logger.info(s"Query ${response.getStatementId} state: $state")
     state match {
       case StatementState.PENDING | StatementState.RUNNING =>
+        logger.info(s"Query is still running, polling again in $POLLING_TIME ms")
         Thread.sleep(POLLING_TIME)
         val response2 = client.statementExecution().getStatement(response.getStatementId)
         getResult(response2)
-      case StatementState.SUCCEEDED => response
-      case StatementState.FAILED =>
-        throw new RuntimeException(s"Query failed: ${response.getStatus.getError.getMessage}")
-      case StatementState.CLOSED =>
-        throw new RuntimeException(s"Query closed: ${response.getStatus.getError.getMessage}")
-      case StatementState.CANCELED =>
-        throw new RuntimeException(s"Query canceled: ${response.getStatus.getError.getMessage}")
+      case StatementState.SUCCEEDED => Right(response)
+      case StatementState.FAILED => Left(s"Query failed: ${response.getStatus.getError.getMessage}")
+      case StatementState.CLOSED => Left(s"Query closed: ${response.getStatus.getError.getMessage}")
+      case StatementState.CANCELED => Left(s"Query canceled: ${response.getStatus.getError.getMessage}")
     }
   }
 
@@ -161,6 +163,26 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
     definition.build()
   }
 
+  // Potential primary key column name found in the table's constraint metadata.
+  private var primaryKeyColumn: Option[String] = None
+
+  // Try to find a primary key constraint over one column.
+  if (databricksTable.getTableConstraints == null) {
+    logger.warn(s"No constraints found for table $tableFullName")
+  } else {
+    databricksTable.getTableConstraints.forEach { constraint =>
+      val primaryKeyConstraint = constraint.getPrimaryKeyConstraint
+      if (primaryKeyConstraint != null) {
+        if (primaryKeyConstraint.getChildColumns.size != 1) {
+          logger.warn("Ignoring composite primary key")
+        } else {
+          primaryKeyColumn = Some(primaryKeyConstraint.getChildColumns.iterator().next())
+          logger.info(s"Found primary key ($primaryKeyColumn)")
+        }
+      }
+    }
+  }
+
   private def columnType(info: ColumnInfo): Option[Type] = {
     val builder = Type.newBuilder()
     val columnType = info.getTypeName
@@ -230,6 +252,11 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
     }
   }
 
+  override def uniqueColumn: String = {
+    // Fall back to the first column if no primary key was found.
+    primaryKeyColumn.getOrElse(databricksTable.getColumns.asScala.head.getName)
+  }
+
   private def rawValueToParameter(v: Value): StatementParameterListItem = {
     logger.debug(s"Converting value to parameter: $v")
     val parameter = new StatementParameterListItem()
@@ -286,4 +313,99 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
     }
   }
 
+  override def insert(row: Row): Row = {
+    bulkInsert(Seq(row)).head
+  }
+
+  // INSERTs can be batched, but only by inlining values in the query string.
+  // We don't want to send gigantic query strings accidentally, so we try to
+  // keep each query around the size below.
+  private val MAX_INSERT_CODE_SIZE = 2048
+
+  override def bulkInsert(rows: Seq[Row]): Seq[Row] = {
+    // There's no bulk call in Databricks, so we inline values. We build
+    // batches of query strings that are at most MAX_INSERT_CODE_SIZE
+    // characters long and loop until all rows are consumed.
+    val columnNames = databricksTable.getColumns.asScala.map(_.getName)
+    val values = rows.map { row =>
+      val data = row.getDataMap
+      columnNames
+        .map { name =>
+          val value = data.get(name)
+          if (value == null) {
+            "DEFAULT"
+          } else {
+            rawValueToDatabricksQueryString(value)
+          }
+        }
+        .mkString("(", ",", ")")
+    }
+    val stmt = new ExecuteStatementRequest()
+      .setWarehouseId(warehouseID)
+      .setDisposition(Disposition.INLINE)
+      .setFormat(Format.JSON_ARRAY)
+
+    val items = values.iterator
+    while (items.nonEmpty) {
+      val item = items.next()
+      val code = StringBuilder.newBuilder
+      code.append(s"INSERT INTO ${databricksTable.getName} VALUES $item")
+      while (code.size < MAX_INSERT_CODE_SIZE && items.hasNext) {
+        code.append(s", ${items.next()}")
+      }
+      stmt.setStatement(code.toString())
+      val executeAPI = client.statementExecution()
+      val response = executeAPI.executeStatement(stmt)
+      getResult(response).left.foreach(error => throw new RuntimeException(error))
+    }
+    rows
+  }
+
+  override def delete(rowId: Value): Unit = {
+    if (primaryKeyColumn.isEmpty) {
+      throw new IllegalArgumentException(s"Table $tableFullName has no primary key column")
+    }
+    val stmt = new ExecuteStatementRequest()
+      .setWarehouseId(warehouseID)
+      .setDisposition(Disposition.INLINE)
+      .setFormat(Format.JSON_ARRAY)
+    stmt.setStatement(
+      s"DELETE FROM ${databricksTable.getName} WHERE ${databricksColumnName(uniqueColumn)} = ${rawValueToDatabricksQueryString(rowId)}"
+    )
+    val executeAPI = client.statementExecution()
+    val response = executeAPI.executeStatement(stmt)
+    getResult(response).left.foreach(error => throw new RuntimeException(error))
+  }
+
+  // How many rows are accepted in a batch update. Technically we're unlimited
+  // since updates are sent one by one.
+  private val MODIFY_BATCH_SIZE = 1000
+
+  override def modifyBatchSize: Int = {
+    MODIFY_BATCH_SIZE
+  }
+
+  override def update(rowId: Value, newValues: Row): Row = {
+    if (primaryKeyColumn.isEmpty) {
+      throw new IllegalArgumentException(s"Table $tableFullName has no primary key column")
+    }
+    val buffer = mutable.Buffer.empty[String]
+    newValues.getDataMap
+      .forEach {
+        case (name, value) =>
+          buffer.append(s"${databricksColumnName(name)} = ${rawValueToDatabricksQueryString(value)}")
+      }
+    val setValues = buffer.mkString(", ")
+    val stmt = new ExecuteStatementRequest()
+      .setWarehouseId(warehouseID)
+      .setDisposition(Disposition.INLINE)
+      .setFormat(Format.JSON_ARRAY)
+    stmt.setStatement(
+      s"UPDATE ${databricksTable.getName} SET $setValues WHERE ${databricksColumnName(uniqueColumn)} = ${rawValueToDatabricksQueryString(rowId)}"
+    )
+    val executeAPI = client.statementExecution()
+    val response = executeAPI.executeStatement(stmt)
+    getResult(response).left.foreach(error => throw new RuntimeException(error))
+    newValues
+  }
 }
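
Note on the batching in bulkInsert above: the Databricks statement-execution API has no bulk-insert call, so the patch folds pre-rendered value tuples into INSERT statements whose text stays around MAX_INSERT_CODE_SIZE characters. The following is a minimal, self-contained sketch of that batching loop only; names such as batchInserts and MaxInsertCodeSize are illustrative and not part of the patch, and it runs without the Databricks SDK.

object InsertBatchingSketch {
  // Illustrative cap on the generated SQL text, mirroring MAX_INSERT_CODE_SIZE.
  val MaxInsertCodeSize = 2048

  // Folds pre-rendered value tuples, e.g. "(1, 'a')", into INSERT statements
  // whose text stays around the cap, consuming the input left to right.
  def batchInserts(table: String, values: Seq[String]): Seq[String] = {
    val statements = scala.collection.mutable.Buffer.empty[String]
    val items = values.iterator
    while (items.hasNext) {
      val code = new StringBuilder(s"INSERT INTO $table VALUES ${items.next()}")
      while (code.length < MaxInsertCodeSize && items.hasNext) {
        code.append(s", ${items.next()}")
      }
      statements += code.toString()
    }
    statements.toSeq
  }

  def main(args: Array[String]): Unit = {
    batchInserts("demo.t", Seq("(1, 'a')", "(2, 'b')", "(3, 'c')")).foreach(println)
  }
}

Each call returns one SQL string per batch; the real code instead sends each batch through an ExecuteStatementRequest and polls for completion.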
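
The getResult change replaces thrown RuntimeExceptions with an Either, so execute can decide how to surface failures (it throws DASSdkException, while the write paths wrap the message in RuntimeException). Below is a rough sketch of the same tail-recursive polling pattern over a toy state type, with an assumed fixed delay standing in for POLLING_TIME and the Databricks StatementState values; it is not the patch's code.

import scala.annotation.tailrec

object PollingSketch {
  sealed trait State
  case object Pending extends State
  final case class Succeeded(rows: Int) extends State
  final case class Failed(msg: String) extends State

  // Assumed polling delay in milliseconds, standing in for POLLING_TIME.
  val PollingTimeMs = 100L

  // Polls until the statement leaves the pending state, returning the
  // failure message on the Left and the result on the Right.
  @tailrec
  def poll(fetch: () => State): Either[String, Int] = fetch() match {
    case Pending =>
      Thread.sleep(PollingTimeMs)
      poll(fetch)
    case Succeeded(rows) => Right(rows)
    case Failed(msg) => Left(s"Query failed: $msg")
  }

  def main(args: Array[String]): Unit = {
    val states = Iterator[State](Pending, Pending, Succeeded(42))
    println(poll(() => states.next())) // prints Right(42)
  }
}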