|
| 1 | +import contextlib |
1 | 2 | import datetime
|
2 | 3 | import io
|
3 | 4 | import logging
|
| 5 | +import uuid |
4 | 6 | from collections.abc import Iterator
|
5 | 7 | from decimal import Decimal
|
6 | 8 | from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
|
7 | 9 |
|
8 | 10 | from google.cloud.bigquery import (
|
9 | 11 | ArrayQueryParameter,
|
10 | 12 | Client,
|
| 13 | + ExtractJobConfig, |
11 | 14 | LoadJobConfig,
|
12 | 15 | QueryJob,
|
13 | 16 | QueryJobConfig,
|
14 | 17 | ScalarQueryParameter,
|
| 18 | + SourceFormat, |
15 | 19 | WriteDisposition,
|
16 | 20 | )
|
17 | 21 | from google.cloud.bigquery.table import Row as BigQueryRow
|
|
32 | 36 | from sqlspec.utils.serializers import to_json
|
33 | 37 |
|
34 | 38 | if TYPE_CHECKING:
|
| 39 | + from pathlib import Path |
| 40 | + |
35 | 41 | from sqlglot.dialects.dialect import DialectType
|
36 | 42 |
|
37 | 43 |
|
@@ -258,23 +264,17 @@ def _run_query_job(
|
258 | 264 | param_value,
|
259 | 265 | type(param_value),
|
260 | 266 | )
|
261 |
| - # Let BigQuery generate the job ID to avoid collisions |
262 |
| - # This is the recommended approach for production code and works better with emulators |
263 |
| - logger.warning("About to send to BigQuery - SQL: %r", sql_str) |
264 |
| - logger.warning("Query parameters in job config: %r", final_job_config.query_parameters) |
265 | 267 | query_job = conn.query(sql_str, job_config=final_job_config)
|
266 | 268 |
|
267 | 269 | # Get the auto-generated job ID for callbacks
|
268 | 270 | if self.on_job_start and query_job.job_id:
|
269 |
| - try: |
| 271 | + with contextlib.suppress(Exception): |
| 272 | + # Callback errors should not interfere with job execution |
270 | 273 | self.on_job_start(query_job.job_id)
|
271 |
| - except Exception as e: |
272 |
| - logger.warning("Job start callback failed: %s", str(e), extra={"adapter": "bigquery"}) |
273 | 274 | if self.on_job_complete and query_job.job_id:
|
274 |
| - try: |
| 275 | + with contextlib.suppress(Exception): |
| 276 | + # Callback errors should not interfere with job execution |
275 | 277 | self.on_job_complete(query_job.job_id, query_job)
|
276 |
| - except Exception as e: |
277 |
| - logger.warning("Job complete callback failed: %s", str(e), extra={"adapter": "bigquery"}) |
278 | 278 |
|
279 | 279 | return query_job
|
280 | 280 |
|
@@ -529,28 +529,120 @@ def _connection(self, connection: "Optional[Client]" = None) -> "Client":
|
529 | 529 | # BigQuery Native Export Support
|
530 | 530 | # ============================================================================
|
531 | 531 |
|
532 |
| - def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int: |
533 |
| - """BigQuery native export implementation. |
| 532 | + def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int: |
| 533 | + """BigQuery native export implementation with automatic GCS staging. |
534 | 534 |
|
535 |
| - For local files, BigQuery doesn't support direct export, so we raise NotImplementedError |
536 |
| - to trigger the fallback mechanism that uses fetch + write. |
| 535 | + For GCS URIs, uses direct export. For other locations, automatically stages |
| 536 | + through a temporary GCS location and transfers to the final destination. |
537 | 537 |
|
538 | 538 | Args:
|
539 | 539 | query: SQL query to execute
|
540 |
| - destination_uri: Destination URI (local file path or gs:// URI) |
| 540 | + destination_uri: Destination URI (local file path, gs:// URI, or Path object) |
541 | 541 | format: Export format (parquet, csv, json, avro)
|
542 |
| - **options: Additional export options |
| 542 | + **options: Additional export options including 'gcs_staging_bucket' |
543 | 543 |
|
544 | 544 | Returns:
|
545 | 545 | Number of rows exported
|
546 | 546 |
|
547 | 547 | Raises:
|
548 |
| - NotImplementedError: Always, to trigger fallback to fetch + write |
| 548 | + NotImplementedError: If no staging bucket is configured for non-GCS destinations |
549 | 549 | """
|
550 |
| - # BigQuery only supports native export to GCS, not local files |
551 |
| - # By raising NotImplementedError, the mixin will fall back to fetch + write |
552 |
| - msg = "BigQuery native export only supports GCS URIs, using fallback for local files" |
553 |
| - raise NotImplementedError(msg) |
| 550 | + destination_str = str(destination_uri) |
| 551 | + |
| 552 | + # If it's already a GCS URI, use direct export |
| 553 | + if destination_str.startswith("gs://"): |
| 554 | + return self._export_to_gcs_native(query, destination_str, format, **options) |
| 555 | + |
| 556 | + # For non-GCS destinations, check if staging is configured |
| 557 | + staging_bucket = options.get("gcs_staging_bucket") or getattr(self.config, "gcs_staging_bucket", None) |
| 558 | + if not staging_bucket: |
| 559 | + # Fall back to fetch + write for non-GCS destinations without staging |
| 560 | + msg = "BigQuery native export requires GCS staging bucket for non-GCS destinations" |
| 561 | + raise NotImplementedError(msg) |
| 562 | + |
| 563 | + # Generate temporary GCS path |
| 564 | + from datetime import timezone |
| 565 | + |
| 566 | + timestamp = datetime.datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") |
| 567 | + temp_filename = f"bigquery_export_{timestamp}_{uuid.uuid4().hex[:8]}.{format}" |
| 568 | + temp_gcs_uri = f"gs://{staging_bucket}/temp_exports/{temp_filename}" |
| 569 | + |
| 570 | + try: |
| 571 | + # Export to temporary GCS location |
| 572 | + rows_exported = self._export_to_gcs_native(query, temp_gcs_uri, format, **options) |
| 573 | + |
| 574 | + # Transfer from GCS to final destination using storage backend |
| 575 | + backend, path = self._resolve_backend_and_path(destination_str) |
| 576 | + gcs_backend = self._get_storage_backend(temp_gcs_uri) |
| 577 | + |
| 578 | + # Download from GCS and upload to final destination |
| 579 | + data = gcs_backend.read_bytes(temp_gcs_uri) |
| 580 | + backend.write_bytes(path, data) |
| 581 | + |
| 582 | + return rows_exported |
| 583 | + finally: |
| 584 | + # Clean up temporary file |
| 585 | + try: |
| 586 | + gcs_backend = self._get_storage_backend(temp_gcs_uri) |
| 587 | + gcs_backend.delete(temp_gcs_uri) |
| 588 | + except Exception as e: |
| 589 | + logger.warning("Failed to clean up temporary GCS file %s: %s", temp_gcs_uri, e) |
| 590 | + |
| 591 | + def _export_to_gcs_native(self, query: str, gcs_uri: str, format: str, **options: Any) -> int: |
| 592 | + """Direct BigQuery export to GCS. |
| 593 | +
|
| 594 | + Args: |
| 595 | + query: SQL query to execute |
| 596 | + gcs_uri: GCS destination URI (must start with gs://) |
| 597 | + format: Export format (parquet, csv, json, avro) |
| 598 | + **options: Additional export options |
| 599 | +
|
| 600 | + Returns: |
| 601 | + Number of rows exported |
| 602 | + """ |
| 603 | + # First, run the query and store results in a temporary table |
| 604 | + |
| 605 | + temp_table_id = f"temp_export_{uuid.uuid4().hex[:8]}" |
| 606 | + dataset_id = getattr(self.connection, "default_dataset", None) or options.get("dataset", "temp") |
| 607 | + |
| 608 | + # Create a temporary table with query results |
| 609 | + query_with_table = f"CREATE OR REPLACE TABLE `{dataset_id}.{temp_table_id}` AS {query}" |
| 610 | + create_job = self._run_query_job(query_with_table, []) |
| 611 | + create_job.result() |
| 612 | + |
| 613 | + # Get row count |
| 614 | + count_query = f"SELECT COUNT(*) as cnt FROM `{dataset_id}.{temp_table_id}`" |
| 615 | + count_job = self._run_query_job(count_query, []) |
| 616 | + count_result = list(count_job.result()) |
| 617 | + row_count = count_result[0]["cnt"] if count_result else 0 |
| 618 | + |
| 619 | + try: |
| 620 | + # Configure extract job |
| 621 | + extract_config = ExtractJobConfig(**options) # type: ignore[no-untyped-call] |
| 622 | + |
| 623 | + # Set format |
| 624 | + format_mapping = { |
| 625 | + "parquet": SourceFormat.PARQUET, |
| 626 | + "csv": SourceFormat.CSV, |
| 627 | + "json": SourceFormat.NEWLINE_DELIMITED_JSON, |
| 628 | + "avro": SourceFormat.AVRO, |
| 629 | + } |
| 630 | + extract_config.destination_format = format_mapping.get(format, SourceFormat.PARQUET) |
| 631 | + |
| 632 | + # Extract table to GCS |
| 633 | + table_ref = self.connection.dataset(dataset_id).table(temp_table_id) |
| 634 | + extract_job = self.connection.extract_table(table_ref, gcs_uri, job_config=extract_config) |
| 635 | + extract_job.result() |
| 636 | + |
| 637 | + return row_count |
| 638 | + finally: |
| 639 | + # Clean up temporary table |
| 640 | + try: |
| 641 | + delete_query = f"DROP TABLE IF EXISTS `{dataset_id}.{temp_table_id}`" |
| 642 | + delete_job = self._run_query_job(delete_query, []) |
| 643 | + delete_job.result() |
| 644 | + except Exception as e: |
| 645 | + logger.warning("Failed to clean up temporary table %s: %s", temp_table_id, e) |
554 | 646 |
|
555 | 647 | # ============================================================================
|
556 | 648 | # BigQuery Native Arrow Support
|
|
0 commit comments