From e19921d1b20d30b89f5a554224c7aba7117bbbb2 Mon Sep 17 00:00:00 2001
From: Ishankoradia
Date: Thu, 28 Aug 2025 15:25:11 +0530
Subject: [PATCH] setting connection pool limits for warehouse connection

---
 ddpui/api/charts_api.py                  | 50 ++++++++--------
 ddpui/api/filter_api.py                  | 20 +++++--
 ddpui/core/charts/charts_service.py      | 16 ++++-
 ddpui/datainsights/warehouse/bigquery.py | 54 +++++++++++++++--
 ddpui/datainsights/warehouse/postgres.py | 52 +++++++++++++++--
 .../warehouse/warehouse_factory.py       | 58 +++++++++++++++++++
 6 files changed, 211 insertions(+), 39 deletions(-)

diff --git a/ddpui/api/charts_api.py b/ddpui/api/charts_api.py
index 7b09842cb..58cda7c2e 100644
--- a/ddpui/api/charts_api.py
+++ b/ddpui/api/charts_api.py
@@ -104,26 +104,30 @@ def generate_chart_data_and_config(payload: ChartDataPayload, org_warehouse, cha
     if payload.chart_type == "map":
         return generate_map_data_and_config(payload, org_warehouse, chart_id)
 
-    # Get warehouse client
-    warehouse = charts_service.get_warehouse_client(org_warehouse)
+    try:
+        # Get warehouse client with connection pooling
+        warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)
 
-    # Build query
-    query_builder = charts_service.build_chart_query(payload)
-    logger.debug(f"Query built for {chart_id_str}: {query_builder}")
+        # Build query
+        query_builder = charts_service.build_chart_query(payload)
+        logger.debug(f"Query built for {chart_id_str}: {query_builder}")
 
-    execute_payload = ExecuteChartQuery(
-        chart_type=payload.chart_type,
-        computation_type=payload.computation_type,
-        x_axis=payload.x_axis,
-        y_axis=payload.y_axis,
-        dimension_col=payload.dimension_col,
-        extra_dimension=payload.extra_dimension,
-        metrics=payload.metrics,
-    )
+        execute_payload = ExecuteChartQuery(
+            chart_type=payload.chart_type,
+            computation_type=payload.computation_type,
+            x_axis=payload.x_axis,
+            y_axis=payload.y_axis,
+            dimension_col=payload.dimension_col,
+            extra_dimension=payload.extra_dimension,
+            metrics=payload.metrics,
+        )
 
-    # Execute query
-    logger.info(f"Executing query for {chart_id_str}")
-    dict_results = charts_service.execute_chart_query(warehouse, query_builder, execute_payload)
+        # Execute query
+        logger.info(f"Executing query for {chart_id_str}")
+        dict_results = charts_service.execute_chart_query(warehouse, query_builder, execute_payload)
+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for chart query: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
 
     logger.debug(f"Query results for {chart_id_str}: {len(dict_results)} rows")
     # Transform data for chart
@@ -172,8 +176,8 @@ def generate_map_data_and_config(payload: ChartDataPayload, org_warehouse, chart
     geojson = get_object_or_404(GeoJSON, id=geojson_id)
 
-    # Get warehouse client and build query
-    warehouse = charts_service.get_warehouse_client(org_warehouse)
+    # Get warehouse client with connection pooling and build query
+    warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)
     query_builder = build_map_query(payload)
 
     # Execute query
@@ -389,8 +393,8 @@ def get_map_data_overlay(request, payload: MapDataOverlayPayload):
         extra_config=extra_config if extra_config else None,
     )
 
-    # Get warehouse client and build query using standard chart service
-    warehouse = charts_service.get_warehouse_client(org_warehouse)
+    # Get warehouse client with connection pooling and build query using standard chart service
+    warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)
     query_builder = charts_service.build_chart_query(chart_payload)
 
     # Add filters if provided
@@ -625,8 +629,8 @@ def generate_map_chart_data(request, payload: ChartDataPayload):
     if not org_warehouse:
         raise HttpError(400, "No warehouse configured for this organization")
 
-    # Get warehouse client
-    warehouse = charts_service.get_warehouse_client(org_warehouse)
+    # Get warehouse client with connection pooling
+    warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)
 
     # Build query using existing service
     query_builder = build_map_query(payload)
diff --git a/ddpui/api/filter_api.py b/ddpui/api/filter_api.py
index 5554d84bb..8ff2f9886 100644
--- a/ddpui/api/filter_api.py
+++ b/ddpui/api/filter_api.py
@@ -95,7 +95,7 @@ def list_schemas(request):
         raise HttpError(404, "Warehouse not configured")
 
     try:
-        warehouse_client = get_warehouse_client(org_warehouse)
+        warehouse_client = get_warehouse_client(org_warehouse, enable_pooling=True)
 
         # Build query using AggQueryBuilder
         query_builder = AggQueryBuilder()
@@ -125,6 +125,9 @@ def list_schemas(request):
         logger.info(f"Found {len(schemas)} schemas for org {orguser.org.id}")
         return schemas
 
+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for schema query: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     except Exception as e:
         logger.error(f"Error fetching schemas: {str(e)}")
         raise HttpError(500, "Error fetching schemas")
@@ -142,7 +145,7 @@ def list_tables(request, schema_name: str):
         raise HttpError(404, "Warehouse not configured")
 
     try:
-        warehouse_client = get_warehouse_client(org_warehouse)
+        warehouse_client = get_warehouse_client(org_warehouse, enable_pooling=True)
 
         # Build query using AggQueryBuilder
         query_builder = AggQueryBuilder()
@@ -172,6 +175,9 @@ def list_tables(request, schema_name: str):
         logger.info(f"Found {len(tables)} tables in schema {schema_name}")
         return tables
 
+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for table query: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     except Exception as e:
         logger.error(f"Error fetching tables for schema {schema_name}: {str(e)}")
         raise HttpError(500, "Error fetching tables")
@@ -191,7 +197,7 @@ def list_columns(request, schema_name: str, table_name: str):
         raise HttpError(404, "Warehouse not configured")
 
     try:
-        warehouse_client = get_warehouse_client(org_warehouse)
+        warehouse_client = get_warehouse_client(org_warehouse, enable_pooling=True)
 
         # Build query using AggQueryBuilder
         query_builder = AggQueryBuilder()
@@ -247,6 +253,9 @@ def list_columns(request, schema_name: str, table_name: str):
         logger.info(f"Found {len(columns)} columns in table {schema_name}.{table_name}")
         return columns
 
+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for column query: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     except Exception as e:
         logger.error(f"Error fetching columns for table {schema_name}.{table_name}: {str(e)}")
         raise HttpError(500, "Error fetching columns")
@@ -271,7 +280,7 @@ def get_filter_preview(
         raise HttpError(404, "Warehouse not configured")
 
     try:
-        warehouse_client = get_warehouse_client(org_warehouse)
+        warehouse_client = get_warehouse_client(org_warehouse, enable_pooling=True)
 
         if filter_type == "value":
             # Get distinct values with counts for categorical filter
@@ -380,6 +389,9 @@ def get_filter_preview(
         else:
             raise HttpError(400, f"Invalid filter type: {filter_type}")
 
+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for filter preview: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     except Exception as e:
         logger.error(f"Error getting filter preview: {str(e)}")
         raise HttpError(500, "Error getting filter preview")
diff --git a/ddpui/core/charts/charts_service.py b/ddpui/core/charts/charts_service.py
index f9222d506..801523126 100644
--- a/ddpui/core/charts/charts_service.py
+++ b/ddpui/core/charts/charts_service.py
@@ -86,9 +86,19 @@ def safe_get_value(row: Dict[str, Any], key: str, null_label: Optional[str] = No
     return handle_null_value(value, null_label)
 
 
-def get_warehouse_client(org_warehouse: OrgWarehouse) -> Warehouse:
-    """Get warehouse client using the standard method"""
-    return WarehouseFactory.get_warehouse_client(org_warehouse)
+def get_warehouse_client(
+    org_warehouse: OrgWarehouse, enable_pooling: bool = True, connection_tier: str = "medium"
+) -> Warehouse:
+    """Get warehouse client with optional connection pooling"""
+
+    if enable_pooling:
+        # Use connection pooled factory when enabled
+        from ddpui.datainsights.warehouse.warehouse_factory import ConnectionPooledWarehouseFactory
+
+        return ConnectionPooledWarehouseFactory.get_warehouse_client(org_warehouse, connection_tier)
+    else:
+        # Fall back to standard factory
+        return WarehouseFactory.get_warehouse_client(org_warehouse)
 
 
 def convert_value(value: Any, preserve_none: bool = False) -> Any:
diff --git a/ddpui/datainsights/warehouse/bigquery.py b/ddpui/datainsights/warehouse/bigquery.py
index b90604976..bb0f4d694 100644
--- a/ddpui/datainsights/warehouse/bigquery.py
+++ b/ddpui/datainsights/warehouse/bigquery.py
@@ -4,6 +4,9 @@
 from sqlalchemy import inspect
 from sqlalchemy.types import NullType
 from sqlalchemy_bigquery._types import _type_map
+from threading import Lock
+import hashlib
+import json
 
 from ddpui.datainsights.insights.insight_interface import MAP_TRANSLATE_TYPES
 from ddpui.datainsights.warehouse.warehouse_interface import Warehouse
@@ -14,16 +17,57 @@ class BigqueryClient(Warehouse):
-    def __init__(self, creds: dict):
+    # Class-level engine cache and lock
+    _engines = {}
+    _engine_lock = Lock()
+
+    @staticmethod
+    def _generate_engine_key(creds: dict, connection_config: dict = None) -> str:
+        """Generate a unique hash key from credentials and config"""
+        # Create a copy to avoid modifying original
+        key_data = creds.copy()
+
+        # Add connection config to the key if provided
+        if connection_config:
+            key_data["_connection_config"] = connection_config
+
+        # Sort keys for consistent ordering
+        key_string = json.dumps(key_data, sort_keys=True, default=str)
+
+        # Generate hash
+        return hashlib.md5(key_string.encode()).hexdigest()
+
+    def __init__(self, creds: dict, connection_config: dict = None):
         """
-        Establish connection to the postgres database using sqlalchemy engine
+        Establish connection to the BigQuery database using sqlalchemy engine
         Creds come from the secrets manager
+        connection_config: Optional dict with connection pooling parameters
         """
         connection_string = "bigquery://{project_id}".format(**creds)
 
-        self.engine = create_engine(
-            connection_string, credentials_info=creds, pool_size=5, pool_timeout=30
-        )
+        # Default connection pool configuration
+        default_pool_config = {
+            "pool_size": 5,
+            "max_overflow": 2,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        }
+
+        # Use provided config or defaults
+        pool_config = connection_config if connection_config else default_pool_config
+
+        # Create engine key from all credentials and config
+        engine_key = self._generate_engine_key(creds, connection_config)
+
+        # Use singleton pattern for engines
+        with self._engine_lock:
+            if engine_key not in self._engines:
+                self._engines[engine_key] = create_engine(
+                    connection_string, credentials_info=creds, **pool_config
+                )
+        self.engine = self._engines[engine_key]
+
         self.inspect_obj: Inspector = inspect(
             self.engine
         )  # this will be used to fetch metadata of the database
diff --git a/ddpui/datainsights/warehouse/postgres.py b/ddpui/datainsights/warehouse/postgres.py
index a4db80932..49708c8dc 100644
--- a/ddpui/datainsights/warehouse/postgres.py
+++ b/ddpui/datainsights/warehouse/postgres.py
@@ -1,5 +1,8 @@
 import tempfile
 from urllib.parse import quote
+from threading import Lock
+import hashlib
+import json
 
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine.reflection import Inspector
@@ -12,10 +15,31 @@ class PostgresClient(Warehouse):
-    def __init__(self, creds: dict):
+    # Class-level engine cache and lock
+    _engines = {}
+    _engine_lock = Lock()
+
+    @staticmethod
+    def _generate_engine_key(creds: dict, connection_config: dict = None) -> str:
+        """Generate a unique hash key from credentials and config"""
+        # Create a copy to avoid modifying original
+        key_data = creds.copy()
+
+        # Add connection config to the key if provided
+        if connection_config:
+            key_data["_connection_config"] = connection_config
+
+        # Sort keys for consistent ordering
+        key_string = json.dumps(key_data, sort_keys=True, default=str)
+
+        # Generate hash
+        return hashlib.md5(key_string.encode()).hexdigest()
+
+    def __init__(self, creds: dict, connection_config: dict = None):
         """
         Establish connection to the postgres database using sqlalchemy engine
         Creds come from the secrets manager
+        connection_config: Optional dict with connection pooling parameters
         """
         creds["encoded_username"] = quote(creds["username"].strip())
         creds["encoded_password"] = quote(creds["password"].strip())
@@ -55,9 +79,29 @@ def __init__(self, creds: dict):
                 fp.write(creds["sslmode"]["ca_certificate"].encode())
                 connection_args["sslrootcert"] = fp.name
 
-        self.engine = create_engine(
-            connection_string, connect_args=connection_args, pool_size=5, pool_timeout=30
-        )
+        # Default connection pool configuration
+        default_pool_config = {
+            "pool_size": 5,
+            "max_overflow": 2,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        }
+
+        # Use provided config or defaults
+        pool_config = connection_config if connection_config else default_pool_config
+
+        # Create engine key from all credentials and config
+        engine_key = self._generate_engine_key(creds, connection_config)
+
+        # Use singleton pattern for engines
+        with self._engine_lock:
+            if engine_key not in self._engines:
+                self._engines[engine_key] = create_engine(
+                    connection_string, connect_args=connection_args, **pool_config
+                )
+        self.engine = self._engines[engine_key]
+
         self.inspect_obj: Inspector = inspect(
             self.engine
         )  # this will be used to fetch metadata of the database
diff --git a/ddpui/datainsights/warehouse/warehouse_factory.py b/ddpui/datainsights/warehouse/warehouse_factory.py
index de29011ca..fbc03f544 100644
--- a/ddpui/datainsights/warehouse/warehouse_factory.py
+++ b/ddpui/datainsights/warehouse/warehouse_factory.py
@@ -26,3 +26,61 @@ def get_warehouse_client(cls, org_warehouse: OrgWarehouse) -> Warehouse:
             raise ValueError("Warehouse credentials not found")
 
         return cls.connect(creds, org_warehouse.wtype)
+
+
+class ConnectionPooledWarehouseFactory:
+    """Factory for creating warehouse clients with connection pooling"""
+
+    # Connection pool configurations per org tier
+    CONNECTION_CONFIGS = {
+        "small": {
+            "pool_size": 3,
+            "max_overflow": 1,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        },
+        "medium": {
+            "pool_size": 5,
+            "max_overflow": 2,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        },
+        "large": {
+            "pool_size": 10,
+            "max_overflow": 5,
+            "pool_timeout": 60,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        },
+    }
+
+    @classmethod
+    def connect(cls, creds: dict, wtype: str, connection_tier: str = "medium") -> Warehouse:
+        """Create warehouse connection with connection pooling"""
+
+        connection_config = cls.CONNECTION_CONFIGS.get(
+            connection_tier, cls.CONNECTION_CONFIGS["medium"]
+        )
+
+        if wtype == WarehouseType.POSTGRES:
+            return PostgresClient(creds, connection_config=connection_config)
+        elif wtype == WarehouseType.BIGQUERY:
+            return BigqueryClient(creds, connection_config=connection_config)
+        else:
+            raise ValueError(f"Warehouse type {wtype} not supported for connection pooling")
+
+    @classmethod
+    def get_warehouse_client(
+        cls, org_warehouse: OrgWarehouse, connection_tier: str = "medium"
+    ) -> Warehouse:
+        """Get warehouse client with connection pooling"""
+        if not org_warehouse:
+            raise ValueError("Organization warehouse not configured")
+
+        creds = retrieve_warehouse_credentials(org_warehouse)
+        if not creds:
+            raise ValueError("Warehouse credentials not found")
+
+        return cls.connect(creds, org_warehouse.wtype, connection_tier)
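
Usage note (outside the patch itself): a minimal sketch of how a caller could request a specific pool tier directly. The helper name get_pooled_client and its concurrency heuristic are illustrative assumptions; only ConnectionPooledWarehouseFactory, charts_service.get_warehouse_client, and the tier names come from the changes above.

    # Sketch only: the tier heuristic below is an assumption, not part of this patch.
    from ddpui.datainsights.warehouse.warehouse_factory import ConnectionPooledWarehouseFactory


    def get_pooled_client(org_warehouse, expected_concurrency: int = 10):
        """Pick a pool tier from expected concurrency (illustrative heuristic)."""
        if expected_concurrency <= 5:
            tier = "small"
        elif expected_concurrency <= 20:
            tier = "medium"
        else:
            tier = "large"
        # Engines are cached per credentials-and-config hash, so repeated calls with
        # the same creds and tier reuse the same SQLAlchemy connection pool.
        return ConnectionPooledWarehouseFactory.get_warehouse_client(org_warehouse, tier)


    # Existing call sites keep going through the service wrapper, which now defaults to pooling:
    # warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)

Requests that exhaust pool_size plus max_overflow connections wait up to pool_timeout seconds for a free connection; the API layers above translate a TimeoutError into an HTTP 503 rather than a generic 500.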