13 changes: 13 additions & 0 deletions model-engine/model_engine_server/common/config.py
@@ -104,6 +104,19 @@ def cache_redis_url(self) -> str:
creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role
return creds["cache-url"]

# Check if we're in an onprem environment with direct Redis access
if os.environ.get('ONPREM_REDIS_HOST'):
Collaborator: why do we not pass this in via config in the same way as the other redis configs?

Reply: updated

# Onprem Redis configuration
redis_host = os.environ.get('ONPREM_REDIS_HOST')
redis_port = os.environ.get('ONPREM_REDIS_PORT', '6379')
redis_password = os.environ.get('ONPREM_REDIS_PASSWORD')

if redis_password:
return f"redis://:{redis_password}@{redis_host}:{redis_port}/0"
else:
return f"redis://{redis_host}:{redis_port}/0"

# Azure Redis configuration (existing logic)
assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
username = os.getenv("AZURE_OBJECT_ID")
token = DefaultAzureCredential().get_token("https://redis.azure.com/.default")
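The review thread above asks for the onprem settings to flow through the service config rather than ad-hoc ONPREM_REDIS_* environment variables (the reply indicates this was updated in a later commit). As a rough illustration only — the class and field names below are hypothetical, not what the PR ended up with — a config-driven version of the same URL construction might look like:

from dataclasses import dataclass
from typing import Optional

@dataclass
class CacheRedisConfig:  # hypothetical stand-in for the real config object this property lives on
    onprem_redis_host: Optional[str] = None
    onprem_redis_port: int = 6379
    onprem_redis_password: Optional[str] = None

    @property
    def cache_redis_url(self) -> str:
        if not self.onprem_redis_host:
            raise ValueError("onprem Redis is not configured")
        auth = f":{self.onprem_redis_password}@" if self.onprem_redis_password else ""
        return f"redis://{auth}{self.onprem_redis_host}:{self.onprem_redis_port}/0"

The resulting URL format matches the one built from environment variables in the hunk above.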
7 changes: 6 additions & 1 deletion model-engine/model_engine_server/core/aws/roles.py
@@ -114,12 +114,17 @@ def assume_role(role_arn: str, role_session_name: Optional[str] = None) -> AwsCr
)


def session(role: Optional[str], session_type: SessionT = Session) -> SessionT:
def session(role: Optional[str], session_type: SessionT = Session) -> Optional[SessionT]:
Collaborator: shouldn't need to touch this; this should only be used if cloud_provider == 'aws'

"""Obtain an AWS session using an arbitrary caller-specified role.

:param:`session_type` defines the type of session to return. Most users will use
the default boto3 type. Some users require a special type (e.g. an aioboto3 session).
"""
# Check if AWS is disabled
if os.environ.get('DISABLE_AWS') == 'true':
logger.warning(f"AWS disabled - skipping role assumption (ignoring: {role})")
return None

# Do not assume roles in CIRCLECI
if os.getenv("CIRCLECI"):
logger.warning(f"In circleci, not assuming role (ignoring: {role})")
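The comment above argues that roles.session() itself should not need an AWS-disabled escape hatch, because it should only ever be called when cloud_provider == "aws". A hedged sketch of gating at the call site instead (the wrapper name is made up; infra_config and session are the existing helpers, with the infra_config import path assumed):

from typing import Optional

from model_engine_server.core.aws.roles import session
from model_engine_server.core.config import infra_config  # import path assumed

def maybe_aws_session(role: Optional[str]):
    # Hypothetical call-site guard: only the AWS code path ever reaches role assumption,
    # so session() can keep its original contract of always returning a real session.
    if infra_config().cloud_provider != "aws":
        return None
    return session(role)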
6 changes: 6 additions & 0 deletions model-engine/model_engine_server/core/aws/secrets.py
@@ -1,6 +1,7 @@
"""AWS secrets module."""

import json
import os
from functools import lru_cache
from typing import Optional

@@ -14,6 +15,11 @@

@lru_cache(maxsize=2)
def get_key_file(secret_name: str, aws_profile: Optional[str] = None):
# Check if AWS Secrets Manager is disabled
if os.environ.get('DISABLE_AWS_SECRETS_MANAGER') == 'true':
logger.warning(f"AWS Secrets Manager disabled - cannot retrieve secret: {secret_name}")
return {}

if aws_profile is not None:
session = boto3.Session(profile_name=aws_profile)
secret_manager = session.client("secretsmanager", region_name=infra_config().default_region)
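One caveat worth noting alongside this hunk: when the secrets backend is disabled, get_key_file() returns an empty dict, so a caller that indexes straight into the result (for example creds["cache-url"] in the config.py hunk above) would raise a bare KeyError. A small, hypothetical caller-side guard:

from model_engine_server.core.aws.secrets import get_key_file

creds = get_key_file("example-secret-name")  # hypothetical secret name; returns {} when DISABLE_AWS_SECRETS_MANAGER=true
cache_url = creds.get("cache-url")
if cache_url is None:
    raise RuntimeError("cache-url unavailable: secrets backend disabled or secret missing")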
33 changes: 22 additions & 11 deletions model-engine/model_engine_server/core/celery/app.py
@@ -530,18 +530,29 @@ def _get_backend_url_and_conf(
# use db_num=1 for backend to differentiate from broker
backend_url = get_redis_endpoint(1)
elif backend_protocol == "s3":
backend_url = "s3://"
if aws_role is None:
aws_session = session(infra_config().profile_ml_worker)
# Check if AWS is disabled - if so, fall back to Redis backend
if os.environ.get('DISABLE_AWS') == 'true':
logger.warning("AWS disabled - falling back to Redis backend instead of S3")
backend_url = get_redis_endpoint(1)
else:
aws_session = session(aws_role)
out_conf_changes.update(
{
"s3_boto3_session": aws_session,
"s3_bucket": s3_bucket,
"s3_base_path": s3_base_path,
}
)
backend_url = "s3://"
if aws_role is None:
aws_session = session(infra_config().profile_ml_worker)
else:
aws_session = session(aws_role)

# If AWS is disabled, session will be None - fall back to Redis
if aws_session is None:
logger.warning("AWS session is None - falling back to Redis backend")
backend_url = get_redis_endpoint(1)
else:
out_conf_changes.update(
{
"s3_boto3_session": aws_session,
"s3_bucket": s3_bucket,
"s3_base_path": s3_base_path,
}
)
elif backend_protocol == "abs":
backend_url = f"azureblockblob://{os.getenv('ABS_ACCOUNT_NAME')}"
else:
21 changes: 18 additions & 3 deletions model-engine/model_engine_server/core/loggers.py
@@ -80,9 +80,24 @@ def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Lo
raise ValueError("Name must be a non-empty string.")
logger = logging.getLogger(name)
logger.setLevel(log_level)
logging.basicConfig(
format=LOG_FORMAT,
)

# Thread-safe logging configuration - only configure if not already configured
Collaborator: interesting, was this manifesting in a particular error?

Reply: Yes, there was a specific recursive logging error causing worker crashes:

RuntimeError: reentrant call inside <_io.BufferedWriter name=''>

It occurred during Gunicorn worker startup: multiple worker processes initializing logging at the same time would compete to write to stderr, and the thread-unsafe configuration raised the reentrant-call error and crashed the workers, which in turn triggered the WORKER TIMEOUT errors we were seeing.

if not logger.handlers:
# Use a lock to prevent race conditions in multi-threaded environments
import threading
with threading.Lock():
if not logger.handlers: # Double-check after acquiring lock
# Configure basic logging only if not already configured
if not logging.getLogger().handlers:
logging.basicConfig(
format=LOG_FORMAT,
)
# Add handler to this specific logger if needed
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(handler)

return logger


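For reference, the double-checked pattern above only serializes callers if they share the same lock object; a lock constructed inside the function body is a fresh object on every call. A minimal sketch of the same idea with a single module-level lock (an editor's sketch of the general pattern, not the PR's code):

import logging
import threading

LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"  # placeholder; the module defines its own LOG_FORMAT

_LOGGING_SETUP_LOCK = threading.Lock()  # one lock shared by every caller in the process

def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Logger:
    if not name:
        raise ValueError("Name must be a non-empty string.")
    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    with _LOGGING_SETUP_LOCK:
        if not logger.handlers:  # configure each named logger at most once
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(LOG_FORMAT))
            logger.addHandler(handler)
    return logger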
@@ -5,6 +5,7 @@
"""

import argparse
import os
import subprocess
from typing import List

@@ -14,18 +15,25 @@ def start_gunicorn_server(port: int, num_workers: int, debug: bool) -> None:
additional_args: List[str] = []
if debug:
additional_args.extend(["--reload", "--timeout", "0"])

# Use environment variables for configuration with fallbacks
Collaborator: 👍
timeout = int(os.environ.get('WORKER_TIMEOUT', os.environ.get('GUNICORN_TIMEOUT', 60)))
graceful_timeout = int(os.environ.get('GUNICORN_GRACEFUL_TIMEOUT', timeout))
keep_alive = int(os.environ.get('GUNICORN_KEEP_ALIVE', 2))
worker_class = os.environ.get('GUNICORN_WORKER_CLASS', 'model_engine_server.api.worker.LaunchWorker')

command = [
"gunicorn",
"--bind",
f"[::]:{port}",
"--timeout",
"60",
str(timeout),
"--graceful-timeout",
"60",
str(graceful_timeout),
"--keep-alive",
"2",
str(keep_alive),
"--worker-class",
"model_engine_server.api.worker.LaunchWorker",
worker_class,
"--workers",
f"{num_workers}",
*additional_args,
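A quick illustration of the timeout fallback chain introduced above: WORKER_TIMEOUT takes precedence over GUNICORN_TIMEOUT, which takes precedence over the hard-coded default of 60 (values below are illustrative only):

import os

# Same lookup expression as in start_gunicorn_server(), exercised directly.
os.environ.pop("WORKER_TIMEOUT", None)
os.environ["GUNICORN_TIMEOUT"] = "120"
assert int(os.environ.get("WORKER_TIMEOUT", os.environ.get("GUNICORN_TIMEOUT", 60))) == 120

os.environ["WORKER_TIMEOUT"] = "300"
assert int(os.environ.get("WORKER_TIMEOUT", os.environ.get("GUNICORN_TIMEOUT", 60))) == 300  # WORKER_TIMEOUT wins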
@@ -1,4 +1,5 @@
import json
import os
from typing import Any, Dict, List, Optional

import botocore
@@ -18,28 +19,61 @@
logger = make_logger(logger_name())
backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3"

celery_redis = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.REDIS.value),
backend_protocol=backend_protocol,
)
celery_redis_24h = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.REDIS.value),
task_visibility=TaskVisibility.VISIBILITY_24H,
backend_protocol=backend_protocol,
)
celery_sqs = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.SQS.value),
backend_protocol=backend_protocol,
)
celery_servicebus = celery_app(
None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol
)
# Initialize celery apps lazily to avoid import-time AWS session creation
Collaborator: curious why we're not running into similar issues for our other non-AWS environments

Reply: On-prem this is likely due to there being no credentials and the import failing:
- the container starts with no AWS credentials
- the Python import hits celery_task_queue_gateway.py:19
- boto3.Session() creation fails immediately
- the import exception crashes the container before the application even starts

celery_redis = None
celery_redis_24h = None
celery_sqs = None
celery_servicebus = None

def _get_celery_redis():
global celery_redis
if celery_redis is None:
celery_redis = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.REDIS.value),
backend_protocol=backend_protocol,
)
return celery_redis

def _get_celery_redis_24h():
global celery_redis_24h
if celery_redis_24h is None:
celery_redis_24h = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.REDIS.value),
task_visibility=TaskVisibility.VISIBILITY_24H,
backend_protocol=backend_protocol,
)
return celery_redis_24h

def _get_celery_sqs():
global celery_sqs
if celery_sqs is None:
# Check if SQS broker is disabled or if we're forcing Redis
if os.environ.get('DISABLE_SQS_BROKER') == 'true' or os.environ.get('FORCE_CELERY_REDIS') == 'true':
logger.warning("SQS broker disabled - using Redis instead")
return _get_celery_redis()
celery_sqs = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.SQS.value),
backend_protocol=backend_protocol,
)
return celery_sqs

def _get_celery_servicebus():
global celery_servicebus
if celery_servicebus is None:
# Check if ServiceBus broker is disabled or if we're forcing Redis
if os.environ.get('DISABLE_SERVICEBUS_BROKER') == 'true' or os.environ.get('FORCE_CELERY_REDIS') == 'true':
logger.warning("ServiceBus broker disabled - using Redis instead")
return _get_celery_redis()
celery_servicebus = celery_app(
None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol
)
return celery_servicebus
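To make the reasoning from the thread above concrete: anything constructed at module scope runs during import, so in a credential-less container the exception surfaces before the application can start, while the lazy getters above defer that work to first use. A generic sketch with a hypothetical factory (not the real celery_app() behaviour):

from typing import Optional

def make_cloud_client():
    # Hypothetical stand-in for a constructor that needs cloud credentials.
    raise RuntimeError("no credentials available")

# Eager, module-level construction (the old behaviour, in spirit):
#     client = make_cloud_client()   # would raise right here, at import time, crashing the importer

# Lazy construction (the pattern this file switches to):
_client: Optional[object] = None

def get_client():
    global _client
    if _client is None:
        _client = make_cloud_client()  # only fails if and when this path is actually exercised
    return _client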


class CeleryTaskQueueGateway(TaskQueueGateway):
@@ -55,13 +89,13 @@ def __init__(self, broker_type: BrokerType, tracing_gateway: TracingGateway):

def _get_celery_dest(self):
if self.broker_type == BrokerType.SQS:
return celery_sqs
return _get_celery_sqs()
elif self.broker_type == BrokerType.REDIS_24H:
return celery_redis_24h
return _get_celery_redis_24h()
elif self.broker_type == BrokerType.REDIS:
return celery_redis
return _get_celery_redis()
else:
return celery_servicebus
return _get_celery_servicebus()

def send_task(
self,