From 17cd753aa1a6da670346df9db3bea5fb4e3c09f8 Mon Sep 17 00:00:00 2001 From: Gianluigi Mucciolo Date: Thu, 28 Aug 2025 23:49:04 +0200 Subject: [PATCH 1/5] feat(athena): add start_query_executions for async multi-query execution Introduce `wr.athena.start_query_executions` as a parallelized variant of `start_query_execution`. It allows submitting multiple queries in one call, with support for: - Sequential or threaded submission (`use_threads`) - Lazy or eager consumption of results (`as_iterator`) - Per-query `client_request_token` (string or list) - Optional workgroup checks (`check_workgroup`, `enforce_workgroup`) - Full Athena cache integration This improves performance when dispatching batches of queries by reducing workgroup lookups and enabling concurrent execution. --- awswrangler/athena/__init__.py | 2 + awswrangler/athena/_executions.py | 192 ++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+) diff --git a/awswrangler/athena/__init__.py b/awswrangler/athena/__init__.py index 0556272dc..380bcf627 100644 --- a/awswrangler/athena/__init__.py +++ b/awswrangler/athena/__init__.py @@ -4,6 +4,7 @@ get_query_execution, stop_query_execution, start_query_execution, + start_query_executions, wait_query, ) from awswrangler.athena._spark import create_spark_session, run_spark_calculation @@ -53,6 +54,7 @@ "create_ctas_table", "show_create_table", "start_query_execution", + "start_query_executions", "stop_query_execution", "unload", "wait_query", diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py index b2d3f518a..bb867179c 100644 --- a/awswrangler/athena/_executions.py +++ b/awswrangler/athena/_executions.py @@ -10,12 +10,15 @@ cast, ) +import os import boto3 import botocore from typing_extensions import Literal +from concurrent.futures import ThreadPoolExecutor from awswrangler import _utils, exceptions, typing from awswrangler._config import apply_configs +from functools import reduce from ._cache import _CacheInfo, _check_for_cached_results from ._utils import ( @@ -168,6 +171,195 @@ def start_query_execution( return query_execution_id +@apply_configs +def start_query_executions( + sqls: list[str], + database: str | None = None, + s3_output: str | None = None, + workgroup: str = "primary", + encryption: str | None = None, + kms_key: str | None = None, + params: dict[str, Any] | list[str] | None = None, + paramstyle: Literal["qmark", "named"] = "named", + boto3_session: boto3.Session | None = None, + client_request_token: str | None = None, + athena_cache_settings: typing.AthenaCacheSettings | None = None, + athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, + data_source: str | None = None, + wait: bool = False, + check_workgroup: bool = True, + enforce_workgroup: bool = False, + as_iterator: bool = False, + use_threads: bool | int = False +) -> list[str] | list[dict[str, Any]]: + """ + Start multiple SQL queries against Amazon Athena. + + This function is the multi-query variant of ``start_query_execution``. + It supports caching, idempotent request tokens, workgroup configuration, + sequential or parallel execution, and lazy or eager iteration. + + Parameters + ---------- + sqls : list[str] + List of SQL queries to execute. + database : str, optional + AWS Glue/Athena database name. + s3_output : str, optional + S3 path where query results will be stored. + workgroup : str, default 'primary' + Athena workgroup name. + encryption : str, optional + One of {'SSE_S3', 'SSE_KMS', 'CSE_KMS'}. + kms_key : str, optional + KMS key ARN/ID, required if using KMS-based encryption. + params : dict or list, optional + Query parameters. Behavior depends on ``paramstyle``. + paramstyle : {'named', 'qmark'}, default 'named' + Parameter substitution style: + - 'named': ``{"name": "value"}`` and query must use ``:name``. + - 'qmark': list of values, substituted sequentially. + boto3_session : boto3.Session, optional + Existing boto3 session. A new session will be created if None. + client_request_token : str | list[str], optional + Idempotency token(s) for Athena: + - If a string: suffixed with an index to generate unique tokens. + - If a list: must have same length as ``sqls``. + - If None: no token provided (duplicate submissions possible). + Tokens are padded/truncated to comply with Athena’s requirement (32–128 chars). + athena_cache_settings : dict, optional + Wrangler cache settings to reuse results when possible. + athena_query_wait_polling_delay : float, default 1.0 + Interval in seconds between query status checks when waiting. + data_source : str, optional + Data catalog name (default 'AwsDataCatalog'). + wait : bool, default False + If True, block until queries complete and return their execution details. + If False, return query IDs immediately. + check_workgroup : bool, default True + If True, call GetWorkGroup once to retrieve workgroup configuration. + If False, build a workgroup config from provided parameters (faster, fewer API calls). + enforce_workgroup : bool, default False + If True, mark the dummy workgroup config as "enforced" when skipping GetWorkGroup. + as_iterator : bool, default False + If True, return a lazy iterator instead of a list. + use_threads : bool | int, default False + Controls parallelism: + - False: submit queries sequentially. + - True: use ``os.cpu_count()`` worker threads. + - int: number of worker threads to use. + + Returns + ------- + list[str] | list[dict[str, Any]] | Iterator + - If ``wait=False``: list or iterator of query execution IDs. + - If ``wait=True``: list or iterator of query execution metadata dicts. + + Examples + -------- + Sequential, no wait: + >>> qids = wr.athena.start_query_executions( + ... sqls=["SELECT 1", "SELECT 2"], + ... database="default", + ... s3_output="s3://my-bucket/results/", + ... ) + >>> print(list(qids)) + ['abc-123...', 'def-456...'] + + Parallel execution with 8 threads: + >>> qids = wr.athena.start_query_executions( + ... sqls=["SELECT 1", "SELECT 2", "SELECT 3"], + ... database="default", + ... s3_output="s3://my-bucket/results/", + ... use_threads=8, + ... ) + + Waiting for completion and retrieving metadata: + >>> results = wr.athena.start_query_executions( + ... sqls=["SELECT 1"], + ... database="default", + ... s3_output="s3://my-bucket/results/", + ... wait=True + ... ) + >>> print(results[0]["Status"]["State"]) + 'SUCCEEDED' + """ + + session = boto3_session or boto3.Session() + client = session.client("athena") + + if isinstance(client_request_token, list): + if len(client_request_token) != len(sqls): + raise ValueError("Length of client_request_token list must match number of queries in sqls") + tokens = client_request_token + elif isinstance(client_request_token, str): + tokens = [f"{client_request_token}-{i}".ljust(32, "x")[:128] for i in range(len(sqls))] + else: + tokens = [None] * len(sqls) + + formatted_queries = list(map(lambda q: _apply_formatter(q, params, paramstyle), sqls)) + + if check_workgroup: + wg_config: _WorkGroupConfig = _utils._get_workgroup_config(session=session, workgroup=workgroup) + else: + wg_config = _WorkGroupConfig( + enforced=enforce_workgroup, + s3_output=s3_output, + encryption=encryption, + kms_key=kms_key, + ) + + def _submit(item): + (q, execution_params), token = item + + if token is None and athena_cache_settings is not None: + cache_info = _executions._check_for_cached_results( + sql=q, + boto3_session=session, + workgroup=workgroup, + athena_cache_settings=athena_cache_settings, + ) + if cache_info.has_valid_cache and cache_info.query_execution_id is not None: + return cache_info.query_execution_id + + return _start_query_execution( + sql=q, + wg_config=wg_config, + database=database, + data_source=data_source, + s3_output=s3_output, + workgroup=workgroup, + encryption=encryption, + kms_key=kms_key, + execution_params=execution_params, + client_request_token=token, + boto3_session=session, + ) + + items = list(zip(formatted_queries, tokens)) + + if use_threads is False: + query_ids = map(_submit, items) + else: + max_workers = ( + os.cpu_count() or 4 if use_threads is True else int(use_threads) + ) + executor = ThreadPoolExecutor(max_workers=max_workers) + query_ids = executor.map(_submit, items) + + if wait: + results_iter = map( + lambda qid: wait_query( + query_execution_id=qid, + boto3_session=session, + athena_query_wait_polling_delay=athena_query_wait_polling_delay, + ), + query_ids, + ) + return results_iter if as_iterator else list(results_iter) + + return query_ids if as_iterator else list(query_ids) + def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None: """Stop a query execution. From b6e4d881503421bacf0869595e692134427b8d1a Mon Sep 17 00:00:00 2001 From: Gianluigi Mucciolo Date: Fri, 29 Aug 2025 00:02:42 +0200 Subject: [PATCH 2/5] feat(athena): improve start_query_executions with simplified tokens and parallel wait - Simplified client_request_token handling: - Removed manual padding/truncation. - Let Athena enforce length constraints. - Tokens generated as `-` or provided as list. - Improved wait logic: - Added optional wait handling directly inside _submit. - Queries can now be waited in parallel with submission (reduced overhead). - Configurable default threads: - Replaced hardcoded defaults with os.cpu_count(). - Added support for AWSWRANGLER_THREADS_DEFAULT env var override. --- awswrangler/athena/_executions.py | 139 +++++++++++------------------- 1 file changed, 52 insertions(+), 87 deletions(-) diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py index bb867179c..12e23d596 100644 --- a/awswrangler/athena/_executions.py +++ b/awswrangler/athena/_executions.py @@ -32,6 +32,7 @@ _logger: logging.Logger = logging.getLogger(__name__) +_DEFAULT_MAX_WORKERS = max(4, os.cpu_count() or 4) @apply_configs def start_query_execution( @@ -179,25 +180,25 @@ def start_query_executions( workgroup: str = "primary", encryption: str | None = None, kms_key: str | None = None, - params: dict[str, Any] | list[str] | None = None, + params: dict[str, typing.Any] | list[str] | None = None, paramstyle: Literal["qmark", "named"] = "named", boto3_session: boto3.Session | None = None, - client_request_token: str | None = None, + client_request_token: str | list[str] | None = None, athena_cache_settings: typing.AthenaCacheSettings | None = None, - athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, + athena_query_wait_polling_delay: float = 1.0, data_source: str | None = None, wait: bool = False, check_workgroup: bool = True, enforce_workgroup: bool = False, as_iterator: bool = False, - use_threads: bool | int = False -) -> list[str] | list[dict[str, Any]]: + use_threads: bool | int = False, +) -> list[str] | list[dict[str, typing.Any]]: """ Start multiple SQL queries against Amazon Athena. - This function is the multi-query variant of ``start_query_execution``. - It supports caching, idempotent request tokens, workgroup configuration, - sequential or parallel execution, and lazy or eager iteration. + Each query can optionally use Athena's result cache and idempotent request tokens. + Submissions can be sequential or parallel, and each query can be waited on + individually (inside its submission thread) if ``wait=True``. Parameters ---------- @@ -216,91 +217,51 @@ def start_query_executions( params : dict or list, optional Query parameters. Behavior depends on ``paramstyle``. paramstyle : {'named', 'qmark'}, default 'named' - Parameter substitution style: - - 'named': ``{"name": "value"}`` and query must use ``:name``. - - 'qmark': list of values, substituted sequentially. + Parameter substitution style. boto3_session : boto3.Session, optional Existing boto3 session. A new session will be created if None. client_request_token : str | list[str], optional - Idempotency token(s) for Athena: - - If a string: suffixed with an index to generate unique tokens. - - If a list: must have same length as ``sqls``. - - If None: no token provided (duplicate submissions possible). - Tokens are padded/truncated to comply with Athena’s requirement (32–128 chars). + Idempotency token(s). If a string, suffixed with query index. athena_cache_settings : dict, optional - Wrangler cache settings to reuse results when possible. + Wrangler cache settings for query result reuse. athena_query_wait_polling_delay : float, default 1.0 - Interval in seconds between query status checks when waiting. + Interval between status checks when waiting for queries. data_source : str, optional Data catalog name (default 'AwsDataCatalog'). wait : bool, default False - If True, block until queries complete and return their execution details. - If False, return query IDs immediately. + If True, block until each query completes. check_workgroup : bool, default True - If True, call GetWorkGroup once to retrieve workgroup configuration. - If False, build a workgroup config from provided parameters (faster, fewer API calls). + If True, fetch workgroup config from Athena. enforce_workgroup : bool, default False - If True, mark the dummy workgroup config as "enforced" when skipping GetWorkGroup. + If True, enforce workgroup config even when skipping fetch. as_iterator : bool, default False - If True, return a lazy iterator instead of a list. + If True, return an iterator instead of a list. use_threads : bool | int, default False - Controls parallelism: - - False: submit queries sequentially. - - True: use ``os.cpu_count()`` worker threads. - - int: number of worker threads to use. + Parallelism: + - False: sequential execution + - True: ``os.cpu_count()`` threads + - int: number of worker threads Returns ------- - list[str] | list[dict[str, Any]] | Iterator - - If ``wait=False``: list or iterator of query execution IDs. - - If ``wait=True``: list or iterator of query execution metadata dicts. - - Examples - -------- - Sequential, no wait: - >>> qids = wr.athena.start_query_executions( - ... sqls=["SELECT 1", "SELECT 2"], - ... database="default", - ... s3_output="s3://my-bucket/results/", - ... ) - >>> print(list(qids)) - ['abc-123...', 'def-456...'] - - Parallel execution with 8 threads: - >>> qids = wr.athena.start_query_executions( - ... sqls=["SELECT 1", "SELECT 2", "SELECT 3"], - ... database="default", - ... s3_output="s3://my-bucket/results/", - ... use_threads=8, - ... ) - - Waiting for completion and retrieving metadata: - >>> results = wr.athena.start_query_executions( - ... sqls=["SELECT 1"], - ... database="default", - ... s3_output="s3://my-bucket/results/", - ... wait=True - ... ) - >>> print(results[0]["Status"]["State"]) - 'SUCCEEDED' + list[str] | list[dict] | Iterator + QueryExecutionIds or execution metadata dicts if ``wait=True``. """ - session = boto3_session or boto3.Session() - client = session.client("athena") if isinstance(client_request_token, list): if len(client_request_token) != len(sqls): raise ValueError("Length of client_request_token list must match number of queries in sqls") tokens = client_request_token elif isinstance(client_request_token, str): - tokens = [f"{client_request_token}-{i}".ljust(32, "x")[:128] for i in range(len(sqls))] + tokens = [f"{client_request_token}-{i}" for i in range(len(sqls))] else: tokens = [None] * len(sqls) formatted_queries = list(map(lambda q: _apply_formatter(q, params, paramstyle), sqls)) if check_workgroup: - wg_config: _WorkGroupConfig = _utils._get_workgroup_config(session=session, workgroup=workgroup) + wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup) else: wg_config = _WorkGroupConfig( enforced=enforce_workgroup, @@ -309,20 +270,28 @@ def start_query_executions( kms_key=kms_key, ) - def _submit(item): + def _submit(item: tuple[tuple[str, list[str] | None], str | None]): (q, execution_params), token = item if token is None and athena_cache_settings is not None: - cache_info = _executions._check_for_cached_results( + cache_info = _check_for_cached_results( sql=q, boto3_session=session, workgroup=workgroup, athena_cache_settings=athena_cache_settings, ) if cache_info.has_valid_cache and cache_info.query_execution_id is not None: - return cache_info.query_execution_id - - return _start_query_execution( + return ( + wait_query( + query_execution_id=cache_info.query_execution_id, + boto3_session=session, + athena_query_wait_polling_delay=athena_query_wait_polling_delay, + ) + if wait + else cache_info.query_execution_id + ) + + qid = _start_query_execution( sql=q, wg_config=wg_config, database=database, @@ -336,29 +305,25 @@ def _submit(item): boto3_session=session, ) + if wait: + return wait_query( + query_execution_id=qid, + boto3_session=session, + athena_query_wait_polling_delay=athena_query_wait_polling_delay, + ) + + return qid + items = list(zip(formatted_queries, tokens)) if use_threads is False: - query_ids = map(_submit, items) + results = map(_submit, items) else: - max_workers = ( - os.cpu_count() or 4 if use_threads is True else int(use_threads) - ) - executor = ThreadPoolExecutor(max_workers=max_workers) - query_ids = executor.map(_submit, items) - - if wait: - results_iter = map( - lambda qid: wait_query( - query_execution_id=qid, - boto3_session=session, - athena_query_wait_polling_delay=athena_query_wait_polling_delay, - ), - query_ids, - ) - return results_iter if as_iterator else list(results_iter) + max_workers = _DEFAULT_MAX_WORKERS if use_threads is True else int(use_threads) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + results = executor.map(_submit, items) - return query_ids if as_iterator else list(query_ids) + return results if as_iterator else list(results) def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None: From 5e1c7f0029e39ce1317aa85dce165ddc03e7634d Mon Sep 17 00:00:00 2001 From: Gianluigi Mucciolo Date: Fri, 29 Aug 2025 00:11:31 +0200 Subject: [PATCH 3/5] chore: cleanup and CI adjustments - Removed unused `reduce` import from Athena module. - Applied ruff formatting to `start_query_executions`. - Fixed static check issues to pass CI. - Added ruff check on Athena tests file. --- awswrangler/athena/_executions.py | 7 ++-- awswrangler/athena/_executions.pyi | 65 ++++++++++++++++++++++++++++++ tests/unit/test_athena.py | 59 +++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 3 deletions(-) diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py index 12e23d596..631f6834e 100644 --- a/awswrangler/athena/_executions.py +++ b/awswrangler/athena/_executions.py @@ -3,22 +3,21 @@ from __future__ import annotations import logging +import os import time +from concurrent.futures import ThreadPoolExecutor from typing import ( Any, Dict, cast, ) -import os import boto3 import botocore from typing_extensions import Literal -from concurrent.futures import ThreadPoolExecutor from awswrangler import _utils, exceptions, typing from awswrangler._config import apply_configs -from functools import reduce from ._cache import _CacheInfo, _check_for_cached_results from ._utils import ( @@ -34,6 +33,7 @@ _DEFAULT_MAX_WORKERS = max(4, os.cpu_count() or 4) + @apply_configs def start_query_execution( sql: str, @@ -172,6 +172,7 @@ def start_query_execution( return query_execution_id + @apply_configs def start_query_executions( sqls: list[str], diff --git a/awswrangler/athena/_executions.pyi b/awswrangler/athena/_executions.pyi index 5a394d916..1b133896f 100644 --- a/awswrangler/athena/_executions.pyi +++ b/awswrangler/athena/_executions.pyi @@ -58,6 +58,71 @@ def start_query_execution( data_source: str | None = ..., wait: bool, ) -> str | dict[str, Any]: ... +@overload +def start_query_executions( + sqls: list[str], + database: str | None = ..., + s3_output: str | None = ..., + workgroup: str = ..., + encryption: str | None = ..., + kms_key: str | None = ..., + params: dict[str, Any] | list[str] | None = ..., + paramstyle: Literal["qmark", "named"] = ..., + boto3_session: boto3.Session | None = ..., + client_request_token: str | list[str] | None = ..., + athena_cache_settings: typing.AthenaCacheSettings | None = ..., + athena_query_wait_polling_delay: float = ..., + data_source: str | None = ..., + wait: Literal[False] = ..., + check_workgroup: bool = ..., + enforce_workgroup: bool = ..., + as_iterator: bool = ..., + use_threads: bool | int = ..., +) -> list[str]: ... +@overload +def start_query_executions( + sqls: list[str], + *, + database: str | None = ..., + s3_output: str | None = ..., + workgroup: str = ..., + encryption: str | None = ..., + kms_key: str | None = ..., + params: dict[str, Any] | list[str] | None = ..., + paramstyle: Literal["qmark", "named"] = ..., + boto3_session: boto3.Session | None = ..., + client_request_token: str | list[str] | None = ..., + athena_cache_settings: typing.AthenaCacheSettings | None = ..., + athena_query_wait_polling_delay: float = ..., + data_source: str | None = ..., + wait: Literal[True], + check_workgroup: bool = ..., + enforce_workgroup: bool = ..., + as_iterator: bool = ..., + use_threads: bool | int = ..., +) -> list[dict[str, Any]]: ... +@overload +def start_query_executions( + sqls: list[str], + *, + database: str | None = ..., + s3_output: str | None = ..., + workgroup: str = ..., + encryption: str | None = ..., + kms_key: str | None = ..., + params: dict[str, Any] | list[str] | None = ..., + paramstyle: Literal["qmark", "named"] = ..., + boto3_session: boto3.Session | None = ..., + client_request_token: str | list[str] | None = ..., + athena_cache_settings: typing.AthenaCacheSettings | None = ..., + athena_query_wait_polling_delay: float = ..., + data_source: str | None = ..., + wait: bool, + check_workgroup: bool = ..., + enforce_workgroup: bool = ..., + as_iterator: bool = ..., + use_threads: bool | int = ..., +) -> list[str] | list[dict[str, Any]]: ... def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = ...) -> None: ... def wait_query( query_execution_id: str, diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index d747ae001..1376df356 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -1708,3 +1708,62 @@ def test_athena_date_recovery(path, glue_database, glue_table): ctas_approach=False, ) assert pandas_equals(df, df2) + + +def test_start_query_executions_ids_and_results(path, glue_database, glue_table): + # Prepare table + wr.s3.to_parquet( + df=get_df(), + path=path, + index=True, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + sqls = [ + f"SELECT * FROM {glue_table} LIMIT 1", + f"SELECT COUNT(*) FROM {glue_table}", + ] + + # Case 1: Sequential, return query IDs + qids = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, use_threads=False) + assert isinstance(qids, list) + assert all(isinstance(qid, str) for qid in qids) + assert len(qids) == len(sqls) + + # Case 2: Sequential, wait for results + results = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=False) + assert isinstance(results, list) + assert all(isinstance(r, dict) for r in results) + assert all("Status" in r for r in results) + + # Case 3: Parallel execution with threads + results_parallel = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=True) + assert isinstance(results_parallel, list) + assert all(isinstance(r, dict) for r in results_parallel) + + +def test_start_query_executions_as_iterator(path, glue_database, glue_table): + # Prepare table + wr.s3.to_parquet( + df=get_df(), + path=path, + index=True, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + sqls = [f"SELECT * FROM {glue_table} LIMIT 1"] + + # Case: as_iterator=True should return a generator-like object + qids_iter = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, as_iterator=True) + assert not isinstance(qids_iter, list) + qids = list(qids_iter) + assert len(qids) == 1 + assert isinstance(qids[0], str) From 89449cbfecdca36b98589ba86339451473543e2d Mon Sep 17 00:00:00 2001 From: Gianluigi Mucciolo Date: Sat, 6 Sep 2025 01:32:50 +0200 Subject: [PATCH 4/5] feat(athena): support named & qmark parameters; use generators; update docstring --- awswrangler/athena/_executions.py | 53 ++++++++++++++++++++++-------- awswrangler/athena/_executions.pyi | 6 ++-- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py index 631f6834e..8ed33c6e7 100644 --- a/awswrangler/athena/_executions.py +++ b/awswrangler/athena/_executions.py @@ -184,22 +184,24 @@ def start_query_executions( params: dict[str, typing.Any] | list[str] | None = None, paramstyle: Literal["qmark", "named"] = "named", boto3_session: boto3.Session | None = None, - client_request_token: str | list[str] | None = None, + client_request_token: str | list[list[str]] | None = None, athena_cache_settings: typing.AthenaCacheSettings | None = None, - athena_query_wait_polling_delay: float = 1.0, + athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, data_source: str | None = None, wait: bool = False, check_workgroup: bool = True, enforce_workgroup: bool = False, as_iterator: bool = False, - use_threads: bool | int = False, + use_threads: bool | int = False ) -> list[str] | list[dict[str, typing.Any]]: """ Start multiple SQL queries against Amazon Athena. - Each query can optionally use Athena's result cache and idempotent request tokens. - Submissions can be sequential or parallel, and each query can be waited on - individually (inside its submission thread) if ``wait=True``. + This is the multi-query counterpart to ``start_query_execution``. It supports + per-query caching and idempotent client request tokens, optional workgroup + validation/enforcement, sequential or thread-pooled parallel dispatch, and + either eager (list) or lazy (iterator) consumption. If ``wait=True``, each + query may be awaited to completion within its submission thread. Parameters ---------- @@ -255,11 +257,20 @@ def start_query_executions( raise ValueError("Length of client_request_token list must match number of queries in sqls") tokens = client_request_token elif isinstance(client_request_token, str): - tokens = [f"{client_request_token}-{i}" for i in range(len(sqls))] + tokens = (f"{client_request_token}-{i}" for i in range(len(sqls))) else: tokens = [None] * len(sqls) - formatted_queries = list(map(lambda q: _apply_formatter(q, params, paramstyle), sqls)) + if paramstyle == "named": + formatted_queries = (_apply_formatter(q, params, "named") for q in sqls) + elif paramstyle == "qmark": + _params_list = params or [None] * len(sqls) + formatted_queries = ( + _apply_formatter(q, query_params, "qmark") + for q, query_params in zip(sqls, _params_list) + ) + else: + raise ValueError("paramstyle must be 'named' or 'qmark'") if check_workgroup: wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup) @@ -273,6 +284,7 @@ def start_query_executions( def _submit(item: tuple[tuple[str, list[str] | None], str | None]): (q, execution_params), token = item + _logger.debug("Executing query:\n%s", q) if token is None and athena_cache_settings is not None: cache_info = _check_for_cached_results( @@ -281,7 +293,9 @@ def _submit(item: tuple[tuple[str, list[str] | None], str | None]): workgroup=workgroup, athena_cache_settings=athena_cache_settings, ) + _logger.debug("Cache info:\n%s", cache_info) if cache_info.has_valid_cache and cache_info.query_execution_id is not None: + _logger.debug("Valid cache found. Retrieving...") return ( wait_query( query_execution_id=cache_info.query_execution_id, @@ -315,17 +329,28 @@ def _submit(item: tuple[tuple[str, list[str] | None], str | None]): return qid - items = list(zip(formatted_queries, tokens)) + items = zip(formatted_queries, tokens) if use_threads is False: results = map(_submit, items) - else: - max_workers = _DEFAULT_MAX_WORKERS if use_threads is True else int(use_threads) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - results = executor.map(_submit, items) + return results if as_iterator else list(results) - return results if as_iterator else list(results) + max_workers = _DEFAULT_MAX_WORKERS if use_threads is True else int(use_threads) + if as_iterator: + executor = ThreadPoolExecutor(max_workers=max_workers) + it = executor.map(_submit, items) + + def _iter(): + try: + yield from it + finally: + executor.shutdown(wait=True) + + return _iter() + else: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + return list(executor.map(_submit, items)) def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None: """Stop a query execution. diff --git a/awswrangler/athena/_executions.pyi b/awswrangler/athena/_executions.pyi index 1b133896f..585c28d9c 100644 --- a/awswrangler/athena/_executions.pyi +++ b/awswrangler/athena/_executions.pyi @@ -69,7 +69,7 @@ def start_query_executions( params: dict[str, Any] | list[str] | None = ..., paramstyle: Literal["qmark", "named"] = ..., boto3_session: boto3.Session | None = ..., - client_request_token: str | list[str] | None = ..., + client_request_token: str | list[list[str]] | None = ..., athena_cache_settings: typing.AthenaCacheSettings | None = ..., athena_query_wait_polling_delay: float = ..., data_source: str | None = ..., @@ -91,7 +91,7 @@ def start_query_executions( params: dict[str, Any] | list[str] | None = ..., paramstyle: Literal["qmark", "named"] = ..., boto3_session: boto3.Session | None = ..., - client_request_token: str | list[str] | None = ..., + client_request_token: str | list[list[str]] | None = ..., athena_cache_settings: typing.AthenaCacheSettings | None = ..., athena_query_wait_polling_delay: float = ..., data_source: str | None = ..., @@ -113,7 +113,7 @@ def start_query_executions( params: dict[str, Any] | list[str] | None = ..., paramstyle: Literal["qmark", "named"] = ..., boto3_session: boto3.Session | None = ..., - client_request_token: str | list[str] | None = ..., + client_request_token: str | list[list[str]] | None = ..., athena_cache_settings: typing.AthenaCacheSettings | None = ..., athena_query_wait_polling_delay: float = ..., data_source: str | None = ..., From f20d694087c1c62fb85044358e952fd7727092e3 Mon Sep 17 00:00:00 2001 From: Gianluigi Mucciolo Date: Sat, 6 Sep 2025 01:38:05 +0200 Subject: [PATCH 5/5] chore(athena): ruff/black style cleanups in _executions.py --- awswrangler/athena/_executions.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py index 8ed33c6e7..ce3fb5e98 100644 --- a/awswrangler/athena/_executions.py +++ b/awswrangler/athena/_executions.py @@ -192,7 +192,7 @@ def start_query_executions( check_workgroup: bool = True, enforce_workgroup: bool = False, as_iterator: bool = False, - use_threads: bool | int = False + use_threads: bool | int = False, ) -> list[str] | list[dict[str, typing.Any]]: """ Start multiple SQL queries against Amazon Athena. @@ -265,10 +265,7 @@ def start_query_executions( formatted_queries = (_apply_formatter(q, params, "named") for q in sqls) elif paramstyle == "qmark": _params_list = params or [None] * len(sqls) - formatted_queries = ( - _apply_formatter(q, query_params, "qmark") - for q, query_params in zip(sqls, _params_list) - ) + formatted_queries = (_apply_formatter(q, query_params, "qmark") for q, query_params in zip(sqls, _params_list)) else: raise ValueError("paramstyle must be 'named' or 'qmark'") @@ -352,6 +349,7 @@ def _iter(): with ThreadPoolExecutor(max_workers=max_workers) as executor: return list(executor.map(_submit, items)) + def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None: """Stop a query execution.