-
Notifications
You must be signed in to change notification settings - Fork 67
on prem changes to disable cloud solutions #700
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
16bea80
d72f6c7
4442dd8
3febc44
a1fe268
61a67ae
10f9e4f
0dcd54a
de8cb1a
324fe4d
b138a2e
1ee453a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
from boto3 import Session, client | ||
from botocore.client import BaseClient | ||
from model_engine_server.core.loggers import logger_name, make_logger | ||
from model_engine_server.core.config import infra_config | ||
|
||
logger = make_logger(logger_name()) | ||
|
||
|
@@ -114,12 +115,17 @@ def assume_role(role_arn: str, role_session_name: Optional[str] = None) -> AwsCr | |
) | ||
|
||
|
||
def session(role: Optional[str], session_type: SessionT = Session) -> SessionT: | ||
def session(role: Optional[str], session_type: SessionT = Session) -> Optional[SessionT]: | ||
|
||
"""Obtain an AWS session using an arbitrary caller-specified role. | ||
|
||
:param:`session_type` defines the type of session to return. Most users will use | ||
the default boto3 type. Some users required a special type (e.g aioboto3 session). | ||
""" | ||
# Only create AWS sessions for AWS cloud provider | ||
if infra_config().cloud_provider != "aws": | ||
logger.warning(f"Not using AWS - cloud provider is {infra_config().cloud_provider} (ignoring: {role})") | ||
return None | ||
|
||
# Do not assume roles in CIRCLECI | ||
if os.getenv("CIRCLECI"): | ||
logger.warning(f"In circleci, not assuming role (ignoring: {role})") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""AWS secrets module.""" | ||
|
||
import json | ||
import os | ||
from functools import lru_cache | ||
from typing import Optional | ||
|
||
|
@@ -14,17 +15,15 @@ | |
|
||
@lru_cache(maxsize=2) | ||
def get_key_file(secret_name: str, aws_profile: Optional[str] = None): | ||
# Only use AWS Secrets Manager for AWS cloud provider | ||
if infra_config().cloud_provider != "aws": | ||
logger.warning(f"Not using AWS Secrets Manager - cloud provider is {infra_config().cloud_provider} (cannot retrieve secret: {secret_name})") | ||
return {} | ||
|
||
if aws_profile is not None: | ||
session = boto3.Session(profile_name=aws_profile) | ||
secret_manager = session.client("secretsmanager", region_name=infra_config().default_region) | ||
else: | ||
secret_manager = boto3.client("secretsmanager", region_name=infra_config().default_region) | ||
try: | ||
secret_value = json.loads( | ||
secret_manager.get_secret_value(SecretId=secret_name)["SecretString"] | ||
) | ||
return secret_value | ||
except ClientError as e: | ||
logger.error(e) | ||
logger.error(f"Failed to retrieve secret: {secret_name}") | ||
return {} | ||
response = secret_manager.get_secret_value(SecretId=secret_name) | ||
|
||
return json.loads(response["SecretString"]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,13 +42,22 @@ class _InfraConfig: | |
docker_repo_prefix: str | ||
s3_bucket: str | ||
redis_host: Optional[str] = None | ||
redis_port: Optional[str] = "6379" | ||
redis_password: Optional[str] = None | ||
redis_aws_secret_name: Optional[str] = None | ||
profile_ml_worker: str = "default" | ||
profile_ml_inference_worker: str = "default" | ||
identity_service_url: Optional[str] = None | ||
firehose_role_arn: Optional[str] = None | ||
firehose_stream_name: Optional[str] = None | ||
prometheus_server_address: Optional[str] = None | ||
# AWS disable configuration | ||
disable_aws: bool = False | ||
|
||
disable_aws_secrets_manager: bool = False | ||
# Celery broker configuration | ||
disable_sqs_broker: bool = False | ||
disable_servicebus_broker: bool = False | ||
force_celery_redis: bool = False | ||
|
||
|
||
@dataclass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,9 +80,24 @@ def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Lo | |
raise ValueError("Name must be a non-empty string.") | ||
logger = logging.getLogger(name) | ||
logger.setLevel(log_level) | ||
logging.basicConfig( | ||
format=LOG_FORMAT, | ||
) | ||
|
||
# Thread-safe logging configuration - only configure if not already configured | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. interesting, was this manifesting in a particular error? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, there was a specific recursive logging error causing worker crashes: RuntimeError: reentrant call inside <_io.BufferedWriter name=''> This occurred during Gunicorn worker startup when multiple processes tried to initialize logging simultaneously, causing thread-unsafe logging configuration and race conditions. The error led to worker crashes, which then triggered the WORKER TIMEOUT errors we were seeing. The issue was that multiple Gunicorn workers starting at the same time would compete to write to stderr during logging setup, causing a reentrant call error that crashed the worker processes. |
||
if not logger.handlers: | ||
# Use a lock to prevent race conditions in multi-threaded environments | ||
import threading | ||
with threading.Lock(): | ||
if not logger.handlers: # Double-check after acquiring lock | ||
# Configure basic logging only if not already configured | ||
if not logging.getLogger().handlers: | ||
logging.basicConfig( | ||
format=LOG_FORMAT, | ||
) | ||
# Add handler to this specific logger if needed | ||
if not logger.handlers: | ||
handler = logging.StreamHandler() | ||
handler.setFormatter(logging.Formatter(LOG_FORMAT)) | ||
logger.addHandler(handler) | ||
|
||
return logger | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
""" | ||
|
||
import argparse | ||
import os | ||
import subprocess | ||
from typing import List | ||
|
||
|
@@ -14,18 +15,25 @@ def start_gunicorn_server(port: int, num_workers: int, debug: bool) -> None: | |
additional_args: List[str] = [] | ||
if debug: | ||
additional_args.extend(["--reload", "--timeout", "0"]) | ||
|
||
# Use environment variables for configuration with fallbacks | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
timeout = int(os.environ.get('WORKER_TIMEOUT', os.environ.get('GUNICORN_TIMEOUT', 60))) | ||
graceful_timeout = int(os.environ.get('GUNICORN_GRACEFUL_TIMEOUT', timeout)) | ||
keep_alive = int(os.environ.get('GUNICORN_KEEP_ALIVE', 2)) | ||
worker_class = os.environ.get('GUNICORN_WORKER_CLASS', 'model_engine_server.api.worker.LaunchWorker') | ||
|
||
command = [ | ||
"gunicorn", | ||
"--bind", | ||
f"[::]:{port}", | ||
"--timeout", | ||
"60", | ||
str(timeout), | ||
"--graceful-timeout", | ||
"60", | ||
str(graceful_timeout), | ||
"--keep-alive", | ||
"2", | ||
str(keep_alive), | ||
"--worker-class", | ||
"model_engine_server.api.worker.LaunchWorker", | ||
worker_class, | ||
"--workers", | ||
f"{num_workers}", | ||
*additional_args, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import json | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
import botocore | ||
|
@@ -18,28 +19,59 @@ | |
logger = make_logger(logger_name()) | ||
backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3" | ||
|
||
celery_redis = celery_app( | ||
None, | ||
s3_bucket=infra_config().s3_bucket, | ||
broker_type=str(BrokerType.REDIS.value), | ||
backend_protocol=backend_protocol, | ||
) | ||
celery_redis_24h = celery_app( | ||
None, | ||
s3_bucket=infra_config().s3_bucket, | ||
broker_type=str(BrokerType.REDIS.value), | ||
task_visibility=TaskVisibility.VISIBILITY_24H, | ||
backend_protocol=backend_protocol, | ||
) | ||
celery_sqs = celery_app( | ||
None, | ||
s3_bucket=infra_config().s3_bucket, | ||
broker_type=str(BrokerType.SQS.value), | ||
backend_protocol=backend_protocol, | ||
) | ||
celery_servicebus = celery_app( | ||
None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol | ||
) | ||
# Initialize celery apps lazily to avoid import-time AWS session creation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious why we're not running into similar issues for our other non-AWS environments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On Prem - This is likely due to no creds and the import failing |
||
celery_redis = None | ||
celery_redis_24h = None | ||
celery_sqs = None | ||
celery_servicebus = None | ||
|
||
def _get_celery_redis(): | ||
global celery_redis | ||
if celery_redis is None: | ||
celery_redis = celery_app( | ||
None, | ||
s3_bucket=infra_config().s3_bucket, | ||
broker_type=str(BrokerType.REDIS.value), | ||
backend_protocol=backend_protocol, | ||
) | ||
return celery_redis | ||
|
||
def _get_celery_redis_24h(): | ||
global celery_redis_24h | ||
if celery_redis_24h is None: | ||
celery_redis_24h = celery_app( | ||
None, | ||
s3_bucket=infra_config().s3_bucket, | ||
broker_type=str(BrokerType.REDIS.value), | ||
task_visibility=TaskVisibility.VISIBILITY_24H, | ||
backend_protocol=backend_protocol, | ||
) | ||
return celery_redis_24h | ||
|
||
def _get_celery_sqs(): | ||
global celery_sqs | ||
if celery_sqs is None: | ||
# Check if SQS broker is disabled via cloud provider | ||
if infra_config().cloud_provider != "aws": | ||
raise ValueError(f"SQS broker requires AWS cloud provider, but current provider is {infra_config().cloud_provider}") | ||
celery_sqs = celery_app( | ||
None, | ||
s3_bucket=infra_config().s3_bucket, | ||
broker_type=str(BrokerType.SQS.value), | ||
backend_protocol=backend_protocol, | ||
) | ||
return celery_sqs | ||
|
||
def _get_celery_servicebus(): | ||
global celery_servicebus | ||
if celery_servicebus is None: | ||
# Check if ServiceBus broker is disabled via cloud provider | ||
if infra_config().cloud_provider != "azure": | ||
raise ValueError(f"ServiceBus broker requires Azure cloud provider, but current provider is {infra_config().cloud_provider}") | ||
celery_servicebus = celery_app( | ||
None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol | ||
) | ||
return celery_servicebus | ||
|
||
|
||
class CeleryTaskQueueGateway(TaskQueueGateway): | ||
|
@@ -55,13 +87,13 @@ def __init__(self, broker_type: BrokerType, tracing_gateway: TracingGateway): | |
|
||
def _get_celery_dest(self): | ||
if self.broker_type == BrokerType.SQS: | ||
return celery_sqs | ||
return _get_celery_sqs() | ||
elif self.broker_type == BrokerType.REDIS_24H: | ||
return celery_redis_24h | ||
return _get_celery_redis_24h() | ||
elif self.broker_type == BrokerType.REDIS: | ||
return celery_redis | ||
return _get_celery_redis() | ||
else: | ||
return celery_servicebus | ||
return _get_celery_servicebus() | ||
|
||
def send_task( | ||
self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
anecdotally, we found it a lot easier to performance tune pure uvicorn, so we actually migrated most usage of gunicorn back to uvicorn. That being said, won't block your usage of it