50 changes: 27 additions & 23 deletions ddpui/api/charts_api.py
```diff
@@ -104,26 +104,30 @@ def generate_chart_data_and_config(payload: ChartDataPayload, org_warehouse, chart_id
     if payload.chart_type == "map":
         return generate_map_data_and_config(payload, org_warehouse, chart_id)

-    # Get warehouse client
-    warehouse = charts_service.get_warehouse_client(org_warehouse)
+    try:
+        # Get warehouse client with connection pooling
+        warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)

-    # Build query
-    query_builder = charts_service.build_chart_query(payload)
-    logger.debug(f"Query built for {chart_id_str}: {query_builder}")
+        # Build query
+        query_builder = charts_service.build_chart_query(payload)
+        logger.debug(f"Query built for {chart_id_str}: {query_builder}")

-    execute_payload = ExecuteChartQuery(
-        chart_type=payload.chart_type,
-        computation_type=payload.computation_type,
-        x_axis=payload.x_axis,
-        y_axis=payload.y_axis,
-        dimension_col=payload.dimension_col,
-        extra_dimension=payload.extra_dimension,
-        metrics=payload.metrics,
-    )
+        execute_payload = ExecuteChartQuery(
+            chart_type=payload.chart_type,
+            computation_type=payload.computation_type,
+            x_axis=payload.x_axis,
+            y_axis=payload.y_axis,
+            dimension_col=payload.dimension_col,
+            extra_dimension=payload.extra_dimension,
+            metrics=payload.metrics,
+        )

-    # Execute query
-    logger.info(f"Executing query for {chart_id_str}")
-    dict_results = charts_service.execute_chart_query(warehouse, query_builder, execute_payload)
+        # Execute query
+        logger.info(f"Executing query for {chart_id_str}")
+        dict_results = charts_service.execute_chart_query(warehouse, query_builder, execute_payload)
+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for chart query: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     logger.debug(f"Query results for {chart_id_str}: {len(dict_results)} rows")

     # Transform data for chart
@@ -172,8 +176,8 @@ def generate_map_data_and_config(payload: ChartDataPayload, org_warehouse, chart_id

     geojson = get_object_or_404(GeoJSON, id=geojson_id)

-    # Get warehouse client and build query
-    warehouse = charts_service.get_warehouse_client(org_warehouse)
+    # Get warehouse client with connection pooling and build query
+    warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)
     query_builder = build_map_query(payload)

     # Execute query
@@ -389,8 +393,8 @@ def get_map_data_overlay(request, payload: MapDataOverlayPayload):
         extra_config=extra_config if extra_config else None,
     )

-    # Get warehouse client and build query using standard chart service
-    warehouse = charts_service.get_warehouse_client(org_warehouse)
+    # Get warehouse client with connection pooling and build query using standard chart service
+    warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)
     query_builder = charts_service.build_chart_query(chart_payload)

     # Add filters if provided
@@ -625,8 +629,8 @@ def generate_map_chart_data(request, payload: ChartDataPayload):
     if not org_warehouse:
         raise HttpError(400, "No warehouse configured for this organization")

-    # Get warehouse client
-    warehouse = charts_service.get_warehouse_client(org_warehouse)
+    # Get warehouse client with connection pooling
+    warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=True)

     # Build query using existing service
     query_builder = build_map_query(payload)
```
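The `TimeoutError` → 503 translation added above is repeated verbatim in each endpoint here and again in `filter_api.py` below. A minimal sketch of how it could be factored into a decorator — `pool_timeout_to_503` is a hypothetical helper, not something this PR adds:

```python
import logging
from functools import wraps

from ninja.errors import HttpError  # django-ninja's HttpError, which these endpoints already raise

logger = logging.getLogger(__name__)


def pool_timeout_to_503(operation: str):
    """Translate connection-pool timeouts into HTTP 503 for the wrapped endpoint."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except TimeoutError as exc:
                logger.warning(f"Connection timeout for {operation}: {exc}")
                raise HttpError(503, "Service temporarily unavailable - connection pool timeout")

        return wrapper

    return decorator


# Usage equivalent to the try/except in the diff above:
# @pool_timeout_to_503("chart query")
# def generate_chart_data_and_config(payload, org_warehouse, chart_id): ...
```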
20 changes: 16 additions & 4 deletions ddpui/api/filter_api.py
```diff
@@ -95,7 +95,7 @@ def list_schemas(request):
         raise HttpError(404, "Warehouse not configured")

     try:
-        warehouse_client = get_warehouse_client(org_warehouse)
+        warehouse_client = get_warehouse_client(org_warehouse, enable_pooling=True)

         # Build query using AggQueryBuilder
         query_builder = AggQueryBuilder()
@@ -125,6 +125,9 @@ def list_schemas(request):
         logger.info(f"Found {len(schemas)} schemas for org {orguser.org.id}")
         return schemas

+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for schema query: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     except Exception as e:
         logger.error(f"Error fetching schemas: {str(e)}")
         raise HttpError(500, "Error fetching schemas")
@@ -142,7 +145,7 @@ def list_tables(request, schema_name: str):
         raise HttpError(404, "Warehouse not configured")

     try:
-        warehouse_client = get_warehouse_client(org_warehouse)
+        warehouse_client = get_warehouse_client(org_warehouse, enable_pooling=True)

         # Build query using AggQueryBuilder
         query_builder = AggQueryBuilder()
@@ -172,6 +175,9 @@ def list_tables(request, schema_name: str):
         logger.info(f"Found {len(tables)} tables in schema {schema_name}")
         return tables

+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for table query: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     except Exception as e:
         logger.error(f"Error fetching tables for schema {schema_name}: {str(e)}")
         raise HttpError(500, "Error fetching tables")
@@ -191,7 +197,7 @@ def list_columns(request, schema_name: str, table_name: str):
         raise HttpError(404, "Warehouse not configured")

     try:
-        warehouse_client = get_warehouse_client(org_warehouse)
+        warehouse_client = get_warehouse_client(org_warehouse, enable_pooling=True)

         # Build query using AggQueryBuilder
         query_builder = AggQueryBuilder()
@@ -247,6 +253,9 @@ def list_columns(request, schema_name: str, table_name: str):
         logger.info(f"Found {len(columns)} columns in table {schema_name}.{table_name}")
         return columns

+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for column query: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     except Exception as e:
         logger.error(f"Error fetching columns for table {schema_name}.{table_name}: {str(e)}")
         raise HttpError(500, "Error fetching columns")
@@ -271,7 +280,7 @@ def get_filter_preview(
         raise HttpError(404, "Warehouse not configured")

     try:
-        warehouse_client = get_warehouse_client(org_warehouse)
+        warehouse_client = get_warehouse_client(org_warehouse, enable_pooling=True)

         if filter_type == "value":
             # Get distinct values with counts for categorical filter
@@ -380,6 +389,9 @@ def get_filter_preview(
         else:
             raise HttpError(400, f"Invalid filter type: {filter_type}")

+    except TimeoutError as e:
+        logger.warning(f"Connection timeout for filter preview: {e}")
+        raise HttpError(503, "Service temporarily unavailable - connection pool timeout")
     except Exception as e:
         logger.error(f"Error getting filter preview: {str(e)}")
         raise HttpError(500, "Error getting filter preview")
```
16 changes: 13 additions & 3 deletions ddpui/core/charts/charts_service.py
```diff
@@ -86,9 +86,19 @@ def safe_get_value(row: Dict[str, Any], key: str, null_label: Optional[str] = None
     return handle_null_value(value, null_label)


-def get_warehouse_client(org_warehouse: OrgWarehouse) -> Warehouse:
-    """Get warehouse client using the standard method"""
-    return WarehouseFactory.get_warehouse_client(org_warehouse)
+def get_warehouse_client(
+    org_warehouse: OrgWarehouse, enable_pooling: bool = True, connection_tier: str = "medium"
+) -> Warehouse:
+    """Get warehouse client with optional connection pooling"""
+
+    if enable_pooling:
+        # Use connection pooled factory when enabled
+        from ddpui.datainsights.warehouse.warehouse_factory import ConnectionPooledWarehouseFactory
+
+        return ConnectionPooledWarehouseFactory.get_warehouse_client(org_warehouse, connection_tier)
+    else:
+        # Fall back to standard factory
+        return WarehouseFactory.get_warehouse_client(org_warehouse)


 def convert_value(value: Any, preserve_none: bool = False) -> Any:
```
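Call sites see the same `Warehouse` interface either way, so the new parameters are purely opt-in. A sketch of the three paths (assumes an `org_warehouse` row is in scope; tier names come from `ConnectionPooledWarehouseFactory.CONNECTION_CONFIGS` below):

```python
from ddpui.core.charts import charts_service

# Default: pooled client at the "medium" tier (enable_pooling defaults to True)
warehouse = charts_service.get_warehouse_client(org_warehouse)

# Heavier org: pooled client at the "large" tier
warehouse = charts_service.get_warehouse_client(org_warehouse, connection_tier="large")

# Opt out: fall back to the pre-existing WarehouseFactory path
warehouse = charts_service.get_warehouse_client(org_warehouse, enable_pooling=False)
```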
54 changes: 49 additions & 5 deletions ddpui/datainsights/warehouse/bigquery.py
```diff
@@ -4,6 +4,9 @@
 from sqlalchemy import inspect
 from sqlalchemy.types import NullType
 from sqlalchemy_bigquery._types import _type_map
+from threading import Lock
+import hashlib
+import json

 from ddpui.datainsights.insights.insight_interface import MAP_TRANSLATE_TYPES
 from ddpui.datainsights.warehouse.warehouse_interface import Warehouse
@@ -14,16 +17,57 @@


 class BigqueryClient(Warehouse):
-    def __init__(self, creds: dict):
+    # Class-level engine cache and lock
+    _engines = {}
+    _engine_lock = Lock()
+
+    @staticmethod
+    def _generate_engine_key(creds: dict, connection_config: dict = None) -> str:
+        """Generate a unique hash key from credentials and config"""
+        # Create a copy to avoid modifying original
+        key_data = creds.copy()
+
+        # Add connection config to the key if provided
+        if connection_config:
+            key_data["_connection_config"] = connection_config
+
+        # Sort keys for consistent ordering
+        key_string = json.dumps(key_data, sort_keys=True, default=str)
+
+        # Generate hash
+        return hashlib.md5(key_string.encode()).hexdigest()
+
+    def __init__(self, creds: dict, connection_config: dict = None):
         """
-        Establish connection to the postgres database using sqlalchemy engine
+        Establish connection to the BigQuery database using sqlalchemy engine
         Creds come from the secrets manager
+        connection_config: Optional dict with connection pooling parameters
         """
         connection_string = "bigquery://{project_id}".format(**creds)

-        self.engine = create_engine(
-            connection_string, credentials_info=creds, pool_size=5, pool_timeout=30
-        )
+        # Default connection pool configuration
+        default_pool_config = {
+            "pool_size": 5,
+            "max_overflow": 2,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        }
+
+        # Use provided config or defaults
+        pool_config = connection_config if connection_config else default_pool_config
+
+        # Create engine key from all credentials and config
+        engine_key = self._generate_engine_key(creds, connection_config)
+
+        # Use singleton pattern for engines
+        with self._engine_lock:
+            if engine_key not in self._engines:
+                self._engines[engine_key] = create_engine(
+                    connection_string, credentials_info=creds, **pool_config
+                )
+        self.engine = self._engines[engine_key]

         self.inspect_obj: Inspector = inspect(
             self.engine
         )  # this will be used to fetch metadata of the database
```
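`_generate_engine_key` is pure stdlib, so its cache-key semantics can be checked in isolation. A standalone sketch (hypothetical credential dicts; a module-level copy of the method above, no live warehouse needed):

```python
import hashlib
import json


def generate_engine_key(creds: dict, connection_config: dict = None) -> str:
    """Module-level copy of BigqueryClient._generate_engine_key, for illustration."""
    key_data = creds.copy()
    if connection_config:
        key_data["_connection_config"] = connection_config
    key_string = json.dumps(key_data, sort_keys=True, default=str)
    return hashlib.md5(key_string.encode()).hexdigest()


creds = {"project_id": "demo-project"}  # hypothetical creds
small = {"pool_size": 3, "max_overflow": 1}
medium = {"pool_size": 5, "max_overflow": 2}

# Same creds + same config -> same key -> one shared engine per process
assert generate_engine_key(creds, small) == generate_engine_key(creds, small)

# Changing the pool config (e.g. a different tier) produces a distinct engine
assert generate_engine_key(creds, small) != generate_engine_key(creds, medium)

# No config vs. explicit config are also distinct cache entries
assert generate_engine_key(creds) != generate_engine_key(creds, medium)
```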
52 changes: 48 additions & 4 deletions ddpui/datainsights/warehouse/postgres.py
```diff
@@ -1,5 +1,8 @@
 import tempfile
 from urllib.parse import quote
+from threading import Lock
+import hashlib
+import json

 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine.reflection import Inspector
@@ -12,10 +15,31 @@


 class PostgresClient(Warehouse):
-    def __init__(self, creds: dict):
+    # Class-level engine cache and lock
+    _engines = {}
+    _engine_lock = Lock()
+
+    @staticmethod
+    def _generate_engine_key(creds: dict, connection_config: dict = None) -> str:
+        """Generate a unique hash key from credentials and config"""
+        # Create a copy to avoid modifying original
+        key_data = creds.copy()
+
+        # Add connection config to the key if provided
+        if connection_config:
+            key_data["_connection_config"] = connection_config
+
+        # Sort keys for consistent ordering
+        key_string = json.dumps(key_data, sort_keys=True, default=str)
+
+        # Generate hash
+        return hashlib.md5(key_string.encode()).hexdigest()
+
+    def __init__(self, creds: dict, connection_config: dict = None):
         """
         Establish connection to the postgres database using sqlalchemy engine
         Creds come from the secrets manager
+        connection_config: Optional dict with connection pooling parameters
         """
         creds["encoded_username"] = quote(creds["username"].strip())
         creds["encoded_password"] = quote(creds["password"].strip())
@@ -55,9 +79,29 @@ def __init__(self, creds: dict):
             fp.write(creds["sslmode"]["ca_certificate"].encode())
             connection_args["sslrootcert"] = fp.name

-        self.engine = create_engine(
-            connection_string, connect_args=connection_args, pool_size=5, pool_timeout=30
-        )
+        # Default connection pool configuration
+        default_pool_config = {
+            "pool_size": 5,
+            "max_overflow": 2,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        }
+
+        # Use provided config or defaults
+        pool_config = connection_config if connection_config else default_pool_config
+
+        # Create engine key from all credentials and config
+        engine_key = self._generate_engine_key(creds, connection_config)
+
+        # Use singleton pattern for engines
+        with self._engine_lock:
+            if engine_key not in self._engines:
+                self._engines[engine_key] = create_engine(
+                    connection_string, connect_args=connection_args, **pool_config
+                )
+        self.engine = self._engines[engine_key]

         self.inspect_obj: Inspector = inspect(
             self.engine
         )  # this will be used to fetch metadata of the database
```
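Both clients now share the same lock-guarded, class-level cache shape. The pattern in isolation, as a minimal sketch (simplified: the real key is the MD5 of creds plus pool config):

```python
from threading import Lock


class EngineCache:
    """Process-wide cache holding at most one engine per distinct key."""

    _engines: dict = {}
    _lock = Lock()

    @classmethod
    def get_or_create(cls, key: str, factory):
        # Holding the lock across the check-and-insert stops two concurrent
        # requests with identical credentials from each building an engine.
        with cls._lock:
            if key not in cls._engines:
                cls._engines[key] = factory()
            return cls._engines[key]
```

One reviewable consequence of this shape: cached engines live for the life of the process — nothing in this diff evicts or disposes them.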
58 changes: 58 additions & 0 deletions ddpui/datainsights/warehouse/warehouse_factory.py
```diff
@@ -26,3 +26,61 @@ def get_warehouse_client(cls, org_warehouse: OrgWarehouse) -> Warehouse:
             raise ValueError("Warehouse credentials not found")

         return cls.connect(creds, org_warehouse.wtype)
+
+
+class ConnectionPooledWarehouseFactory:
+    """Factory for creating warehouse clients with connection pooling"""
+
+    # Connection pool configurations per org tier
+    CONNECTION_CONFIGS = {
+        "small": {
+            "pool_size": 3,
+            "max_overflow": 1,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        },
+        "medium": {
+            "pool_size": 5,
+            "max_overflow": 2,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        },
+        "large": {
+            "pool_size": 10,
+            "max_overflow": 5,
+            "pool_timeout": 60,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        },
+    }
+
+    @classmethod
+    def connect(cls, creds: dict, wtype: str, connection_tier: str = "medium") -> Warehouse:
+        """Create warehouse connection with connection pooling"""
+
+        connection_config = cls.CONNECTION_CONFIGS.get(
+            connection_tier, cls.CONNECTION_CONFIGS["medium"]
+        )
+
+        if wtype == WarehouseType.POSTGRES:
+            return PostgresClient(creds, connection_config=connection_config)
+        elif wtype == WarehouseType.BIGQUERY:
+            return BigqueryClient(creds, connection_config=connection_config)
+        else:
+            raise ValueError(f"Warehouse type {wtype} not supported for connection pooling")
+
+    @classmethod
+    def get_warehouse_client(
+        cls, org_warehouse: OrgWarehouse, connection_tier: str = "medium"
+    ) -> Warehouse:
+        """Get warehouse client with connection pooling"""
+        if not org_warehouse:
+            raise ValueError("Organization warehouse not configured")
+
+        creds = retrieve_warehouse_credentials(org_warehouse)
+        if not creds:
+            raise ValueError("Warehouse credentials not found")
+
+        return cls.connect(creds, org_warehouse.wtype, connection_tier)
```

> **Review comment (Contributor)** on the `connection_tier: str = "medium"` default: We will default to medium. But, what will you pass for v2 in prod? Medium or Large?
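A usage sketch of the new factory (assumes `org_warehouse` and `creds` are in scope; the `WarehouseType` import path is a guess from the names in this diff):

```python
from ddpui.datainsights.warehouse.warehouse_factory import ConnectionPooledWarehouseFactory
from ddpui.datainsights.warehouse.warehouse_interface import WarehouseType  # assumed import path

# Typical path: resolve credentials from the org's secrets and pick a tier
warehouse = ConnectionPooledWarehouseFactory.get_warehouse_client(
    org_warehouse, connection_tier="large"
)

# Direct path when creds are already in hand; an unrecognized tier
# silently falls back to the "medium" pool configuration
client = ConnectionPooledWarehouseFactory.connect(
    creds, WarehouseType.POSTGRES, connection_tier="xlarge"
)
```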