Skip to content

Commit 4407524

Browse files
authored
feat: add pathlike object support for storage operations (#39)
- Updated storage mixin method signatures to accept Union[str, Path]
- Updated storage backend protocol signatures for pathlike objects
- Modified BigQuery driver to use datetime.now(timezone.utc) instead of deprecated utcnow()
- Implemented BigQuery optimized import/export with automatic GCS staging for non-GCS paths
- Replaced direct file operations with storage backend calls in:
  - SQLite driver's _bulk_load_file method
  - AIOSQLite driver's _bulk_load_file method
  - SQLFileLoader's _read_file_content method
- Added comprehensive tests for pathlike object support in:
  - Storage mixins integration tests
  - FSSpec backend unit tests
1 parent 14126ba commit 4407524

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

51 files changed

+3021
-1183
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.12.0"
20+
rev: "v0.12.1"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

docs/examples/litestar_duckllm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# type: ignore
12
"""Litestar DuckLLM
23
34
This example demonstrates how to use the Litestar framework with the DuckLLM extension.

docs/examples/litestar_multi_db.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# type: ignore
12
"""Litestar Multi DB
23
34
This example demonstrates how to use multiple databases in a Litestar application.

docs/examples/service_example.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# type: ignore
12
"""Example demonstrating the high-level service layer.
23
34
This example shows how to use the DatabaseService and AsyncDatabaseService

docs/examples/standalone_demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
# type: ignore
23
# /// script
34
# dependencies = [
45
# "sqlspec[duckdb,performance]",

sqlspec/adapters/aiosqlite/driver.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,7 @@ async def _execute_script(
203203
return result
204204

205205
async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
206-
"""Database-specific bulk load implementation."""
207-
# TODO: convert this to use the storage backend. it has async support
206+
"""Database-specific bulk load implementation using storage backend."""
208207
if format != "csv":
209208
msg = f"aiosqlite driver only supports CSV for bulk loading, not {format}."
210209
raise NotImplementedError(msg)
@@ -215,15 +214,21 @@ async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, m
215214
if mode == "replace":
216215
await cursor.execute(f"DELETE FROM {table_name}")
217216

218-
# Using sync file IO here as it's a fallback path and aiofiles is not a dependency
219-
with Path(file_path).open(encoding="utf-8") as f: # noqa: ASYNC230
220-
reader = csv.reader(f, **options)
221-
header = next(reader) # Skip header
222-
placeholders = ", ".join("?" for _ in header)
223-
sql = f"INSERT INTO {table_name} VALUES ({placeholders})"
224-
data_iter = list(reader)
225-
await cursor.executemany(sql, data_iter)
226-
rowcount = cursor.rowcount
217+
# Use async storage backend to read the file
218+
file_path_str = str(file_path)
219+
backend = self._get_storage_backend(file_path_str)
220+
content = await backend.read_text_async(file_path_str, encoding="utf-8")
221+
# Parse CSV content
222+
import io
223+
224+
csv_file = io.StringIO(content)
225+
reader = csv.reader(csv_file, **options)
226+
header = next(reader) # Skip header
227+
placeholders = ", ".join("?" for _ in header)
228+
sql = f"INSERT INTO {table_name} VALUES ({placeholders})"
229+
data_iter = list(reader)
230+
await cursor.executemany(sql, data_iter)
231+
rowcount = cursor.rowcount
227232
await conn.commit()
228233
return rowcount
229234
finally:

sqlspec/adapters/bigquery/driver.py

Lines changed: 113 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
1+
import contextlib
12
import datetime
23
import io
34
import logging
5+
import uuid
46
from collections.abc import Iterator
57
from decimal import Decimal
68
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
79

810
from google.cloud.bigquery import (
911
ArrayQueryParameter,
1012
Client,
13+
ExtractJobConfig,
1114
LoadJobConfig,
1215
QueryJob,
1316
QueryJobConfig,
1417
ScalarQueryParameter,
18+
SourceFormat,
1519
WriteDisposition,
1620
)
1721
from google.cloud.bigquery.table import Row as BigQueryRow
@@ -32,6 +36,8 @@
3236
from sqlspec.utils.serializers import to_json
3337

3438
if TYPE_CHECKING:
39+
from pathlib import Path
40+
3541
from sqlglot.dialects.dialect import DialectType
3642

3743

@@ -258,23 +264,17 @@ def _run_query_job(
258264
param_value,
259265
type(param_value),
260266
)
261-
# Let BigQuery generate the job ID to avoid collisions
262-
# This is the recommended approach for production code and works better with emulators
263-
logger.warning("About to send to BigQuery - SQL: %r", sql_str)
264-
logger.warning("Query parameters in job config: %r", final_job_config.query_parameters)
265267
query_job = conn.query(sql_str, job_config=final_job_config)
266268

267269
# Get the auto-generated job ID for callbacks
268270
if self.on_job_start and query_job.job_id:
269-
try:
271+
with contextlib.suppress(Exception):
272+
# Callback errors should not interfere with job execution
270273
self.on_job_start(query_job.job_id)
271-
except Exception as e:
272-
logger.warning("Job start callback failed: %s", str(e), extra={"adapter": "bigquery"})
273274
if self.on_job_complete and query_job.job_id:
274-
try:
275+
with contextlib.suppress(Exception):
276+
# Callback errors should not interfere with job execution
275277
self.on_job_complete(query_job.job_id, query_job)
276-
except Exception as e:
277-
logger.warning("Job complete callback failed: %s", str(e), extra={"adapter": "bigquery"})
278278

279279
return query_job
280280

@@ -529,28 +529,120 @@ def _connection(self, connection: "Optional[Client]" = None) -> "Client":
529529
# BigQuery Native Export Support
530530
# ============================================================================
531531

532-
def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
533-
"""BigQuery native export implementation.
532+
def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
533+
"""BigQuery native export implementation with automatic GCS staging.
534534
535-
For local files, BigQuery doesn't support direct export, so we raise NotImplementedError
536-
to trigger the fallback mechanism that uses fetch + write.
535+
For GCS URIs, uses direct export. For other locations, automatically stages
536+
through a temporary GCS location and transfers to the final destination.
537537
538538
Args:
539539
query: SQL query to execute
540-
destination_uri: Destination URI (local file path or gs:// URI)
540+
destination_uri: Destination URI (local file path, gs:// URI, or Path object)
541541
format: Export format (parquet, csv, json, avro)
542-
**options: Additional export options
542+
**options: Additional export options including 'gcs_staging_bucket'
543543
544544
Returns:
545545
Number of rows exported
546546
547547
Raises:
548-
NotImplementedError: Always, to trigger fallback to fetch + write
548+
NotImplementedError: If no staging bucket is configured for non-GCS destinations
549549
"""
550-
# BigQuery only supports native export to GCS, not local files
551-
# By raising NotImplementedError, the mixin will fall back to fetch + write
552-
msg = "BigQuery native export only supports GCS URIs, using fallback for local files"
553-
raise NotImplementedError(msg)
550+
destination_str = str(destination_uri)
551+
552+
# If it's already a GCS URI, use direct export
553+
if destination_str.startswith("gs://"):
554+
return self._export_to_gcs_native(query, destination_str, format, **options)
555+
556+
# For non-GCS destinations, check if staging is configured
557+
staging_bucket = options.get("gcs_staging_bucket") or getattr(self.config, "gcs_staging_bucket", None)
558+
if not staging_bucket:
559+
# Fall back to fetch + write for non-GCS destinations without staging
560+
msg = "BigQuery native export requires GCS staging bucket for non-GCS destinations"
561+
raise NotImplementedError(msg)
562+
563+
# Generate temporary GCS path
564+
from datetime import timezone
565+
566+
timestamp = datetime.datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
567+
temp_filename = f"bigquery_export_{timestamp}_{uuid.uuid4().hex[:8]}.{format}"
568+
temp_gcs_uri = f"gs://{staging_bucket}/temp_exports/{temp_filename}"
569+
570+
try:
571+
# Export to temporary GCS location
572+
rows_exported = self._export_to_gcs_native(query, temp_gcs_uri, format, **options)
573+
574+
# Transfer from GCS to final destination using storage backend
575+
backend, path = self._resolve_backend_and_path(destination_str)
576+
gcs_backend = self._get_storage_backend(temp_gcs_uri)
577+
578+
# Download from GCS and upload to final destination
579+
data = gcs_backend.read_bytes(temp_gcs_uri)
580+
backend.write_bytes(path, data)
581+
582+
return rows_exported
583+
finally:
584+
# Clean up temporary file
585+
try:
586+
gcs_backend = self._get_storage_backend(temp_gcs_uri)
587+
gcs_backend.delete(temp_gcs_uri)
588+
except Exception as e:
589+
logger.warning("Failed to clean up temporary GCS file %s: %s", temp_gcs_uri, e)
590+
591+
def _export_to_gcs_native(self, query: str, gcs_uri: str, format: str, **options: Any) -> int:
592+
"""Direct BigQuery export to GCS.
593+
594+
Args:
595+
query: SQL query to execute
596+
gcs_uri: GCS destination URI (must start with gs://)
597+
format: Export format (parquet, csv, json, avro)
598+
**options: Additional export options
599+
600+
Returns:
601+
Number of rows exported
602+
"""
603+
# First, run the query and store results in a temporary table
604+
605+
temp_table_id = f"temp_export_{uuid.uuid4().hex[:8]}"
606+
dataset_id = getattr(self.connection, "default_dataset", None) or options.get("dataset", "temp")
607+
608+
# Create a temporary table with query results
609+
query_with_table = f"CREATE OR REPLACE TABLE `{dataset_id}.{temp_table_id}` AS {query}"
610+
create_job = self._run_query_job(query_with_table, [])
611+
create_job.result()
612+
613+
# Get row count
614+
count_query = f"SELECT COUNT(*) as cnt FROM `{dataset_id}.{temp_table_id}`"
615+
count_job = self._run_query_job(count_query, [])
616+
count_result = list(count_job.result())
617+
row_count = count_result[0]["cnt"] if count_result else 0
618+
619+
try:
620+
# Configure extract job
621+
extract_config = ExtractJobConfig(**options) # type: ignore[no-untyped-call]
622+
623+
# Set format
624+
format_mapping = {
625+
"parquet": SourceFormat.PARQUET,
626+
"csv": SourceFormat.CSV,
627+
"json": SourceFormat.NEWLINE_DELIMITED_JSON,
628+
"avro": SourceFormat.AVRO,
629+
}
630+
extract_config.destination_format = format_mapping.get(format, SourceFormat.PARQUET)
631+
632+
# Extract table to GCS
633+
table_ref = self.connection.dataset(dataset_id).table(temp_table_id)
634+
extract_job = self.connection.extract_table(table_ref, gcs_uri, job_config=extract_config)
635+
extract_job.result()
636+
637+
return row_count
638+
finally:
639+
# Clean up temporary table
640+
try:
641+
delete_query = f"DROP TABLE IF EXISTS `{dataset_id}.{temp_table_id}`"
642+
delete_job = self._run_query_job(delete_query, [])
643+
delete_job.result()
644+
except Exception as e:
645+
logger.warning("Failed to clean up temporary table %s: %s", temp_table_id, e)
554646

555647
# ============================================================================
556648
# BigQuery Native Arrow Support

sqlspec/adapters/duckdb/driver.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import uuid
33
from collections.abc import Generator
44
from contextlib import contextmanager
5+
from pathlib import Path
56
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
67

78
from duckdb import DuckDBPyConnection
@@ -251,7 +252,7 @@ def _has_native_capability(self, operation: str, uri: str = "", format: str = ""
251252
return True
252253
return False
253254

254-
def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
255+
def _export_native(self, query: str, destination_uri: Union[str, Path], format: str, **options: Any) -> int:
255256
conn = self._connection(None)
256257
copy_options: list[str] = []
257258

@@ -283,19 +284,21 @@ def _export_native(self, query: str, destination_uri: str, format: str, **option
283284
raise ValueError(msg)
284285

285286
options_str = f"({', '.join(copy_options)})" if copy_options else ""
286-
copy_sql = f"COPY ({query}) TO '{destination_uri}' {options_str}"
287+
copy_sql = f"COPY ({query}) TO '{destination_uri!s}' {options_str}"
287288
result_rel = conn.execute(copy_sql)
288289
result = result_rel.fetchone() if result_rel else None
289290
return result[0] if result else 0
290291

291-
def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
292+
def _import_native(
293+
self, source_uri: Union[str, Path], table_name: str, format: str, mode: str, **options: Any
294+
) -> int:
292295
conn = self._connection(None)
293296
if format == "parquet":
294-
read_func = f"read_parquet('{source_uri}')"
297+
read_func = f"read_parquet('{source_uri!s}')"
295298
elif format == "csv":
296-
read_func = f"read_csv_auto('{source_uri}')"
299+
read_func = f"read_csv_auto('{source_uri!s}')"
297300
elif format == "json":
298-
read_func = f"read_json_auto('{source_uri}')"
301+
read_func = f"read_json_auto('{source_uri!s}')"
299302
else:
300303
msg = f"Unsupported format for DuckDB native import: {format}"
301304
raise ValueError(msg)
@@ -320,16 +323,16 @@ def _import_native(self, source_uri: str, table_name: str, format: str, mode: st
320323
return int(count_result[0]) if count_result else 0
321324

322325
def _read_parquet_native(
323-
self, source_uri: str, columns: Optional[list[str]] = None, **options: Any
326+
self, source_uri: Union[str, Path], columns: Optional[list[str]] = None, **options: Any
324327
) -> "SQLResult[dict[str, Any]]":
325328
conn = self._connection(None)
326329
if isinstance(source_uri, list):
327330
file_list = "[" + ", ".join(f"'{f}'" for f in source_uri) + "]"
328331
read_func = f"read_parquet({file_list})"
329-
elif "*" in source_uri or "?" in source_uri:
330-
read_func = f"read_parquet('{source_uri}')"
332+
elif "*" in str(source_uri) or "?" in str(source_uri):
333+
read_func = f"read_parquet('{source_uri!s}')"
331334
else:
332-
read_func = f"read_parquet('{source_uri}')"
335+
read_func = f"read_parquet('{source_uri!s}')"
333336

334337
column_list = ", ".join(columns) if columns else "*"
335338
query = f"SELECT {column_list} FROM {read_func}"
@@ -353,7 +356,9 @@ def _read_parquet_native(
353356
statement=SQL(query), data=rows, column_names=column_names, rows_affected=num_rows, operation_type="SELECT"
354357
)
355358

356-
def _write_parquet_native(self, data: Union[str, "ArrowTable"], destination_uri: str, **options: Any) -> None:
359+
def _write_parquet_native(
360+
self, data: Union[str, "ArrowTable"], destination_uri: Union[str, Path], **options: Any
361+
) -> None:
357362
conn = self._connection(None)
358363
copy_options: list[str] = ["FORMAT PARQUET"]
359364
if "compression" in options:
@@ -364,13 +369,13 @@ def _write_parquet_native(self, data: Union[str, "ArrowTable"], destination_uri:
364369
options_str = f"({', '.join(copy_options)})"
365370

366371
if isinstance(data, str):
367-
copy_sql = f"COPY ({data}) TO '{destination_uri}' {options_str}"
372+
copy_sql = f"COPY ({data}) TO '{destination_uri!s}' {options_str}"
368373
conn.execute(copy_sql)
369374
else:
370375
temp_name = f"_arrow_data_{uuid.uuid4().hex[:8]}"
371376
conn.register(temp_name, data)
372377
try:
373-
copy_sql = f"COPY {temp_name} TO '{destination_uri}' {options_str}"
378+
copy_sql = f"COPY {temp_name} TO '{destination_uri!s}' {options_str}"
374379
conn.execute(copy_sql)
375380
finally:
376381
with contextlib.suppress(Exception):

0 commit comments

Comments
 (0)