12 changes: 12 additions & 0 deletions charts/model-engine/templates/_helpers.tpl
@@ -369,6 +369,18 @@ env:
- name: CIRCLECI
value: "true"
{{- end }}
{{- if .Values.gunicorn }}

Collaborator: anecdotally, we found it a lot easier to performance-tune pure uvicorn, so we actually migrated most usage of gunicorn back to uvicorn. That being said, won't block your usage of it.

- name: WORKER_TIMEOUT
value: {{ .Values.gunicorn.workerTimeout | quote }}
- name: GUNICORN_TIMEOUT
value: {{ .Values.gunicorn.gracefulTimeout | quote }}
- name: GUNICORN_GRACEFUL_TIMEOUT
value: {{ .Values.gunicorn.gracefulTimeout | quote }}
- name: GUNICORN_KEEP_ALIVE
value: {{ .Values.gunicorn.keepAlive | quote }}
- name: GUNICORN_WORKER_CLASS
value: {{ .Values.gunicorn.workerClass | quote }}
{{- end }}
{{- end }}

{{- define "modelEngine.serviceEnvGitTagFromHelmVar" }}
13 changes: 13 additions & 0 deletions model-engine/model_engine_server/common/config.py
@@ -104,6 +104,19 @@ def cache_redis_url(self) -> str:
creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role
return creds["cache-url"]

# Check if we're in an onprem environment with direct Redis access via config
if infra_config().cloud_provider == "onprem" and infra_config().redis_host:
# Onprem Redis configuration
redis_host = infra_config().redis_host
redis_port = infra_config().redis_port
redis_password = infra_config().redis_password

if redis_password:
return f"redis://:{redis_password}@{redis_host}:{redis_port}/0"
else:
return f"redis://{redis_host}:{redis_port}/0"

# Azure Redis configuration (existing logic)
assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
username = os.getenv("AZURE_OBJECT_ID")
token = DefaultAzureCredential().get_token("https://redis.azure.com/.default")
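
To make the new onprem branch concrete, here is a small standalone sketch of the URL it produces; the host and port values are the defaults added to core/configs/default.yaml and are only illustrative:

# Standalone sketch of the onprem cache_redis_url logic above (not the real method).
redis_host = "redis-message-broker-master.default"
redis_port = "6379"
redis_password = None  # e.g. "s3cret" when Redis AUTH is enabled

if redis_password:
    cache_redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/0"
else:
    cache_redis_url = f"redis://{redis_host}:{redis_port}/0"

print(cache_redis_url)  # -> redis://redis-message-broker-master.default:6379/0
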
8 changes: 7 additions & 1 deletion model-engine/model_engine_server/core/aws/roles.py
@@ -12,6 +12,7 @@
from boto3 import Session, client
from botocore.client import BaseClient
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.core.config import infra_config

logger = make_logger(logger_name())

@@ -114,12 +115,17 @@ def assume_role(role_arn: str, role_session_name: Optional[str] = None) -> AwsCr
)


def session(role: Optional[str], session_type: SessionT = Session) -> SessionT:
def session(role: Optional[str], session_type: SessionT = Session) -> Optional[SessionT]:

Collaborator: shouldn't need to touch this; this should only be used if cloud_provider == 'aws'.

"""Obtain an AWS session using an arbitrary caller-specified role.

:param:`session_type` defines the type of session to return. Most users will use
the default boto3 type. Some users required a special type (e.g aioboto3 session).
"""
# Only create AWS sessions for AWS cloud provider
if infra_config().cloud_provider != "aws":
logger.warning(f"Not using AWS - cloud provider is {infra_config().cloud_provider} (ignoring: {role})")
return None

# Do not assume roles in CIRCLECI
if os.getenv("CIRCLECI"):
logger.warning(f"In circleci, not assuming role (ignoring: {role})")
17 changes: 8 additions & 9 deletions model-engine/model_engine_server/core/aws/secrets.py
@@ -1,6 +1,7 @@
"""AWS secrets module."""

import json
import os
from functools import lru_cache
from typing import Optional

@@ -14,17 +15,15 @@

@lru_cache(maxsize=2)
def get_key_file(secret_name: str, aws_profile: Optional[str] = None):
# Only use AWS Secrets Manager for AWS cloud provider
if infra_config().cloud_provider != "aws":
logger.warning(f"Not using AWS Secrets Manager - cloud provider is {infra_config().cloud_provider} (cannot retrieve secret: {secret_name})")
return {}

if aws_profile is not None:
session = boto3.Session(profile_name=aws_profile)
secret_manager = session.client("secretsmanager", region_name=infra_config().default_region)
else:
secret_manager = boto3.client("secretsmanager", region_name=infra_config().default_region)
try:
secret_value = json.loads(
secret_manager.get_secret_value(SecretId=secret_name)["SecretString"]
)
return secret_value
except ClientError as e:
logger.error(e)
logger.error(f"Failed to retrieve secret: {secret_name}")
return {}
response = secret_manager.get_secret_value(SecretId=secret_name)

Collaborator: I think we still want to do the try/except wrapping to handle the cases where the secret_manager client errors out.

return json.loads(response["SecretString"])
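
A minimal sketch of the wrapping the comment above asks for, restoring the old fallback of returning {} on client errors; get_secret_json and its region_name argument are illustrative stand-ins, not the real function:

import json

import boto3
from botocore.exceptions import ClientError


def get_secret_json(secret_name: str, region_name: str) -> dict:
    """Fetch a JSON secret, returning {} on client errors (mirrors the removed behavior)."""
    secret_manager = boto3.client("secretsmanager", region_name=region_name)
    try:
        response = secret_manager.get_secret_value(SecretId=secret_name)
        return json.loads(response["SecretString"])
    except ClientError as exc:
        # Covers e.g. ResourceNotFoundException or AccessDeniedException from the API.
        print(f"Failed to retrieve secret {secret_name}: {exc}")
        return {}
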
9 changes: 9 additions & 0 deletions model-engine/model_engine_server/core/celery/app.py
@@ -530,11 +530,20 @@ def _get_backend_url_and_conf(
# use db_num=1 for backend to differentiate from broker
backend_url = get_redis_endpoint(1)
elif backend_protocol == "s3":
# Only use S3 backend for AWS cloud provider
if infra_config().cloud_provider != "aws":
raise ValueError(f"S3 backend requires AWS cloud provider, but current provider is {infra_config().cloud_provider}")

backend_url = "s3://"
if aws_role is None:
aws_session = session(infra_config().profile_ml_worker)
else:
aws_session = session(aws_role)

# If AWS session creation fails, throw an error
if aws_session is None:
raise ValueError("Failed to create AWS session for S3 backend")

out_conf_changes.update(
{
"s3_boto3_session": aws_session,
9 changes: 9 additions & 0 deletions model-engine/model_engine_server/core/config.py
@@ -42,13 +42,22 @@ class _InfraConfig:
docker_repo_prefix: str
s3_bucket: str
redis_host: Optional[str] = None
redis_port: Optional[str] = "6379"
redis_password: Optional[str] = None
redis_aws_secret_name: Optional[str] = None
profile_ml_worker: str = "default"
profile_ml_inference_worker: str = "default"
identity_service_url: Optional[str] = None
firehose_role_arn: Optional[str] = None
firehose_stream_name: Optional[str] = None
prometheus_server_address: Optional[str] = None
# AWS disable configuration
disable_aws: bool = False

Collaborator: could we instead add a new cloud_provider == "onprem" and change logic based off that?

disable_aws_secrets_manager: bool = False
# Celery broker configuration
disable_sqs_broker: bool = False
disable_servicebus_broker: bool = False
force_celery_redis: bool = False


@dataclass
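
A rough sketch of the approach the comment above suggests: instead of disable_aws-style flags, the config carries cloud_provider: "onprem" and call sites branch on it, the same way the new cache_redis_url logic already does. The InfraConfigSketch class and uses_aws helper below are hypothetical illustrations, not part of the codebase:

from dataclasses import dataclass
from typing import Optional


@dataclass
class InfraConfigSketch:
    # Hypothetical stand-in for _InfraConfig with only the fields used here.
    cloud_provider: str = "aws"  # e.g. "aws", "azure", "onprem"
    redis_host: Optional[str] = None
    redis_port: str = "6379"


def uses_aws(config: InfraConfigSketch) -> bool:
    # Call sites check the provider instead of consulting disable_aws flags.
    return config.cloud_provider == "aws"


config = InfraConfigSketch(cloud_provider="onprem", redis_host="redis-message-broker-master.default")
if not uses_aws(config):
    print("onprem: skipping AWS Secrets Manager, SQS broker, and S3 result backend")
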
9 changes: 9 additions & 0 deletions model-engine/model_engine_server/core/configs/default.yaml
@@ -6,6 +6,8 @@ default_region: "us-west-2"
ml_account_id: "000000000000"
docker_repo_prefix: "000000000000.dkr.ecr.us-west-2.amazonaws.com"
redis_host: "redis-message-broker-master.default"
redis_port: "6379"
redis_password: null
s3_bucket: "test-bucket"
profile_ml_worker: "default"
profile_ml_inference_worker: "default"
@@ -14,3 +16,10 @@ db_engine_max_overflow: 10
db_engine_echo: false
db_engine_echo_pool: false
db_engine_disconnect_strategy: "pessimistic"
# AWS disable configuration
disable_aws: false
disable_aws_secrets_manager: false
# Celery broker configuration
disable_sqs_broker: false
disable_servicebus_broker: false
force_celery_redis: false
21 changes: 18 additions & 3 deletions model-engine/model_engine_server/core/loggers.py
@@ -80,9 +80,24 @@ def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Lo
raise ValueError("Name must be a non-empty string.")
logger = logging.getLogger(name)
logger.setLevel(log_level)
logging.basicConfig(
format=LOG_FORMAT,
)

# Thread-safe logging configuration - only configure if not already configured

Collaborator: interesting, was this manifesting in a particular error?

Reply: Yes, there was a specific recursive logging error causing worker crashes:

    RuntimeError: reentrant call inside <_io.BufferedWriter name=''>

This occurred during Gunicorn worker startup when multiple processes tried to initialize logging simultaneously, leading to thread-unsafe logging configuration and race conditions. The resulting worker crashes were what triggered the WORKER TIMEOUT errors we were seeing: multiple Gunicorn workers starting at the same time would compete to write to stderr during logging setup, causing a reentrant call that crashed the worker processes.


if not logger.handlers:
# Use a lock to prevent race conditions in multi-threaded environments
import threading
with threading.Lock():
if not logger.handlers: # Double-check after acquiring lock
# Configure basic logging only if not already configured
if not logging.getLogger().handlers:
logging.basicConfig(
format=LOG_FORMAT,
)
# Add handler to this specific logger if needed
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(handler)

return logger


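
For reference, a minimal sketch of the double-checked locking idiom for one-time logging setup, with the lock held at module scope so that concurrent callers share the same lock object; the names and format string are illustrative, not the module's code:

import logging
import threading

LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"  # placeholder format
_config_lock = threading.Lock()  # module-level, shared by all callers


def make_logger_sketch(name: str, log_level: int = logging.INFO) -> logging.Logger:
    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    if not logger.handlers:                 # fast path without taking the lock
        with _config_lock:                  # same lock object for every caller
            if not logger.handlers:         # re-check after acquiring the lock
                handler = logging.StreamHandler()
                handler.setFormatter(logging.Formatter(LOG_FORMAT))
                logger.addHandler(handler)
    return logger
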
@@ -5,6 +5,7 @@
"""

import argparse
import os
import subprocess
from typing import List

@@ -14,18 +15,25 @@ def start_gunicorn_server(port: int, num_workers: int, debug: bool) -> None:
additional_args: List[str] = []
if debug:
additional_args.extend(["--reload", "--timeout", "0"])

# Use environment variables for configuration with fallbacks

Collaborator: 👍

timeout = int(os.environ.get('WORKER_TIMEOUT', os.environ.get('GUNICORN_TIMEOUT', 60)))
graceful_timeout = int(os.environ.get('GUNICORN_GRACEFUL_TIMEOUT', timeout))
keep_alive = int(os.environ.get('GUNICORN_KEEP_ALIVE', 2))
worker_class = os.environ.get('GUNICORN_WORKER_CLASS', 'model_engine_server.api.worker.LaunchWorker')

command = [
"gunicorn",
"--bind",
f"[::]:{port}",
"--timeout",
"60",
str(timeout),
"--graceful-timeout",
"60",
str(graceful_timeout),
"--keep-alive",
"2",
str(keep_alive),
"--worker-class",
"model_engine_server.api.worker.LaunchWorker",
worker_class,
"--workers",
f"{num_workers}",
*additional_args,
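
To make the fallback chain concrete, a small sketch with an assumed environment where only WORKER_TIMEOUT is set; the graceful timeout then inherits the worker timeout and the remaining settings fall back to their defaults:

import os

# Assumed example environment: only WORKER_TIMEOUT is set.
os.environ["WORKER_TIMEOUT"] = "120"

timeout = int(os.environ.get("WORKER_TIMEOUT", os.environ.get("GUNICORN_TIMEOUT", 60)))
graceful_timeout = int(os.environ.get("GUNICORN_GRACEFUL_TIMEOUT", timeout))
keep_alive = int(os.environ.get("GUNICORN_KEEP_ALIVE", 2))

print(timeout, graceful_timeout, keep_alive)  # -> 120 120 2
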
celery_task_queue_gateway.py
@@ -1,4 +1,5 @@
import json
import os
from typing import Any, Dict, List, Optional

import botocore
@@ -18,28 +19,59 @@
logger = make_logger(logger_name())
backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3"

celery_redis = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.REDIS.value),
backend_protocol=backend_protocol,
)
celery_redis_24h = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.REDIS.value),
task_visibility=TaskVisibility.VISIBILITY_24H,
backend_protocol=backend_protocol,
)
celery_sqs = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.SQS.value),
backend_protocol=backend_protocol,
)
celery_servicebus = celery_app(
None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol
)
# Initialize celery apps lazily to avoid import-time AWS session creation

Collaborator: curious why we're not running into similar issues for our other non-AWS environments?

Reply: On prem this is likely due to there being no creds, so the import itself fails:
1. The container starts with no AWS credentials.
2. The Python import hits celery_task_queue_gateway.py:19.
3. boto3.Session() creation fails immediately.
4. The import exception crashes the container before the application even starts.


celery_redis = None
celery_redis_24h = None
celery_sqs = None
celery_servicebus = None

def _get_celery_redis():
global celery_redis
if celery_redis is None:
celery_redis = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.REDIS.value),
backend_protocol=backend_protocol,
)
return celery_redis

def _get_celery_redis_24h():
global celery_redis_24h
if celery_redis_24h is None:
celery_redis_24h = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.REDIS.value),
task_visibility=TaskVisibility.VISIBILITY_24H,
backend_protocol=backend_protocol,
)
return celery_redis_24h

def _get_celery_sqs():
global celery_sqs
if celery_sqs is None:
# Check if SQS broker is disabled via cloud provider
if infra_config().cloud_provider != "aws":
raise ValueError(f"SQS broker requires AWS cloud provider, but current provider is {infra_config().cloud_provider}")
celery_sqs = celery_app(
None,
s3_bucket=infra_config().s3_bucket,
broker_type=str(BrokerType.SQS.value),
backend_protocol=backend_protocol,
)
return celery_sqs

def _get_celery_servicebus():
global celery_servicebus
if celery_servicebus is None:
# Check if ServiceBus broker is disabled via cloud provider
if infra_config().cloud_provider != "azure":
raise ValueError(f"ServiceBus broker requires Azure cloud provider, but current provider is {infra_config().cloud_provider}")
celery_servicebus = celery_app(
None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol
)
return celery_servicebus


class CeleryTaskQueueGateway(TaskQueueGateway):
@@ -55,13 +87,13 @@ def __init__(self, broker_type: BrokerType, tracing_gateway: TracingGateway):

def _get_celery_dest(self):
if self.broker_type == BrokerType.SQS:
return celery_sqs
return _get_celery_sqs()
elif self.broker_type == BrokerType.REDIS_24H:
return celery_redis_24h
return _get_celery_redis_24h()
elif self.broker_type == BrokerType.REDIS:
return celery_redis
return _get_celery_redis()
else:
return celery_servicebus
return _get_celery_servicebus()

def send_task(
self,