Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions model-engine/example_sleep_model_deployment/service_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
bundle_config:
model_bundle_name: sleep-model-timeout-test
request_schema: Dict[str, Any]
response_schema: Dict[str, Any]
repository: sleep_model
tag: latest
command:
- python
- app.py
readiness_initial_delay_seconds: 30

endpoint_config:
endpoint_name: sleep-model-timeout-test
model_bundle: sleep-model-timeout-test
cpus: 1
memory: 2Gi
storage: 10Gi
gpus: 0
min_workers: 1
max_workers: 1
per_worker: 1
endpoint_type: async
queue_message_timeout_duration: 90 # 90 seconds to handle 70s inference + buffer
labels:
team: test
product: sleep-model-timeout-test
update_if_exists: True
16 changes: 16 additions & 0 deletions model-engine/example_sleep_model_deployment/sleep_model/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
FROM python:3.9-slim

WORKDIR /app

# Copy requirements and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app.py .

# Expose port
EXPOSE 8080

# Run the application
CMD ["gunicorn", "--bind", "0.0.0.0:8080", "--timeout", "120", "app:app"]
50 changes: 50 additions & 0 deletions model-engine/example_sleep_model_deployment/sleep_model/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Simple sleep model for testing queue timeout duration.
This model sleeps for 70 seconds to test queue lock duration > 60 seconds.
"""

import time
from typing import Any, Dict
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
"""
Prediction endpoint that sleeps for 70 seconds to test queue timeout.
"""
try:
data = request.get_json()
sleep_duration = data.get('sleep_duration', 70) # Default 70 seconds

print(f"Starting inference... will sleep for {sleep_duration} seconds")

# Sleep to simulate long-running inference
time.sleep(sleep_duration)

response = {
"result": f"Completed after sleeping for {sleep_duration} seconds",
"input": data,
"status": "success"
}

print(f"Inference completed successfully after {sleep_duration} seconds")
return jsonify(response)

except Exception as e:
print(f"Error during inference: {e}")
return jsonify({"error": str(e), "status": "failed"}), 500

@app.route('/health', methods=['GET'])
def health():
"""Health check endpoint"""
return jsonify({"status": "healthy"})

@app.route('/readyz', methods=['GET'])
def ready():
"""Readiness check endpoint"""
return jsonify({"status": "ready"})

if __name__ == '__main__':
app.run(host='0.0.0.0', port=8080)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Flask==2.3.3
gunicorn==21.2.0
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class BuildEndpointRequest(BaseModel):
high_priority: Optional[bool] = None
default_callback_url: Optional[str] = None
default_callback_auth: Optional[CallbackAuth] = None
queue_message_timeout_duration: Optional[int] = None


class BuildEndpointStatus(str, Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class CreateModelEndpointV1Request(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = Field(default=False)
queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1)


class CreateModelEndpointV1Response(BaseModel):
Expand Down Expand Up @@ -100,6 +101,7 @@ class UpdateModelEndpointV1Request(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = None
queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1)


class UpdateModelEndpointV1Response(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,7 @@ async def execute(
default_callback_url=request.default_callback_url,
default_callback_auth=request.default_callback_auth,
public_inference=request.public_inference,
queue_message_timeout_duration=request.queue_message_timeout_duration,
)
_handle_post_inference_hooks(
created_by=user.user_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ async def execute(
default_callback_url=request.default_callback_url,
default_callback_auth=request.default_callback_auth,
public_inference=request.public_inference,
queue_message_timeout_duration=request.queue_message_timeout_duration,
)
_handle_post_inference_hooks(
created_by=user.user_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def create_model_endpoint_infra(
billing_tags: Optional[Dict[str, Any]] = None,
default_callback_url: Optional[str],
default_callback_auth: Optional[CallbackAuth],
queue_message_timeout_duration: Optional[int] = None,
) -> str:
deployment_name = generate_deployment_name(
model_endpoint_record.created_by, model_endpoint_record.name
Expand Down Expand Up @@ -104,6 +105,7 @@ def create_model_endpoint_infra(
billing_tags=billing_tags,
default_callback_url=default_callback_url,
default_callback_auth=default_callback_auth,
queue_message_timeout_duration=queue_message_timeout_duration,
)
response = self.task_queue_gateway.send_task(
task_name=BUILD_TASK_NAME,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Any, Dict
from datetime import timedelta
from typing import Any, Dict, Optional

from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.servicebus.management import ServiceBusAdministrationClient
from azure.servicebus.management import ServiceBusAdministrationClient, QueueProperties
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import EndpointResourceInfraException
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
Expand Down Expand Up @@ -32,13 +33,36 @@ async def create_queue_if_not_exists(
endpoint_name: str,
endpoint_created_by: str,
endpoint_labels: Dict[str, Any],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
timeout_duration = queue_message_timeout_duration or 60 # Default to 60 seconds

# Validation: Azure Service Bus lock duration must be <= 5 minutes (300s)
if timeout_duration > 300:
raise ValueError(f"queue_message_timeout_duration ({timeout_duration}s) exceeds Azure Service Bus maximum of 300 seconds")

with _get_servicebus_administration_client() as client:
try:
# First, try to create the queue with default properties
client.create_queue(queue_name=queue_name)

# Then update the queue properties to set custom lock duration
queue_properties = client.get_queue(queue_name)
queue_properties.lock_duration = timedelta(seconds=timeout_duration)
client.update_queue(queue_properties)

except ResourceExistsError:
pass
# Queue already exists, update its properties if needed
try:
queue_properties = client.get_queue(queue_name)
# Only update if the lock duration is different
if queue_properties.lock_duration != timedelta(seconds=timeout_duration):
queue_properties.lock_duration = timedelta(seconds=timeout_duration)
client.update_queue(queue_properties)
except Exception as e:
# If we can't update properties, log but don't fail
logger.warning(f"Could not update queue properties for {queue_name}: {e}")

return QueueInfo(queue_name, None)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Sequence
from typing import Any, Dict, Optional, Sequence

from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
Expand All @@ -15,6 +15,7 @@ async def create_queue_if_not_exists(
endpoint_name: str,
endpoint_created_by: str,
endpoint_labels: Dict[str, Any],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
queue_url = f"http://foobar.com/{queue_name}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ async def create_queue(
self,
endpoint_record: ModelEndpointRecord,
labels: Dict[str, str],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
"""Creates a new queue, returning its unique name and queue URL."""
queue_name, queue_url = await self.queue_delegate.create_queue_if_not_exists(
endpoint_id=endpoint_record.id,
endpoint_name=endpoint_record.name,
endpoint_created_by=endpoint_record.created_by,
endpoint_labels=labels,
queue_message_timeout_duration=queue_message_timeout_duration,
)
return QueueInfo(queue_name, queue_url)

Expand All @@ -56,7 +58,11 @@ async def create_or_update_resources(
request.build_endpoint_request.model_endpoint_record.endpoint_type
== ModelEndpointType.ASYNC
):
q = await self.create_queue(endpoint_record, request.build_endpoint_request.labels)
q = await self.create_queue(
endpoint_record,
request.build_endpoint_request.labels,
request.build_endpoint_request.queue_message_timeout_duration
)
queue_name: Optional[str] = q.queue_name
queue_url: Optional[str] = q.queue_url
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ async def create_queue_if_not_exists(
endpoint_name: str,
endpoint_created_by: str,
endpoint_labels: Dict[str, Any],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
"""
Creates a queue associated with the given endpoint_id. Other fields are set as tags on the queue.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ async def create_queue_if_not_exists(
endpoint_name: str,
endpoint_created_by: str,
endpoint_labels: Dict[str, Any],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
timeout_duration = queue_message_timeout_duration or 60 # Default to 60 seconds

async with _create_async_sqs_client(sqs_profile=self.sqs_profile) as sqs_client:
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)

Expand All @@ -73,9 +76,7 @@ async def create_queue_if_not_exists(
create_response = await sqs_client.create_queue(
QueueName=queue_name,
Attributes=dict(
VisibilityTimeout="43200",
# To match current hardcoded Celery timeout of 24hr
# However, the max SQS visibility is 12hrs.
VisibilityTimeout=str(timeout_duration),
Policy=_get_queue_policy(queue_name=queue_name),
),
tags=_get_queue_tags(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ async def create_batch_job(
owner=owner,
default_callback_url=None,
default_callback_auth=None,
queue_message_timeout_duration=None,
)

await self.batch_job_record_repository.update_batch_job_record(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def get_base_image_params(
return BuildImageRequest(
repo=hmi_config.user_inference_base_repository,
image_tag=resulting_image_tag[:MAX_IMAGE_TAG_LEN],
aws_profile=ECR_AWS_PROFILE, # type: ignore
aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE
base_path=WORKSPACE_PATH,
dockerfile=f"{inference_folder}/{dockerfile}",
base_image=base_image,
Expand Down Expand Up @@ -577,7 +577,7 @@ def _get_user_image_params(
return BuildImageRequest(
repo=ecr_repo,
image_tag=service_image_tag[:MAX_IMAGE_TAG_LEN],
aws_profile=ECR_AWS_PROFILE,
aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE
base_path=WORKSPACE_PATH,
dockerfile=f"{inference_folder}/{dockerfile}",
base_image=base_image,
Expand Down Expand Up @@ -633,7 +633,7 @@ def _get_inject_bundle_image_params(
return BuildImageRequest(
repo=ecr_repo,
image_tag=service_image_tag[:MAX_IMAGE_TAG_LEN],
aws_profile=ECR_AWS_PROFILE,
aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE
base_path=WORKSPACE_PATH,
dockerfile=f"{inference_folder}/{dockerfile}",
base_image=base_image,
Expand Down Expand Up @@ -667,7 +667,7 @@ async def _build_image(
if not self.docker_repository.image_exists(
repository_name=image_params.repo,
image_tag=image_params.image_tag,
aws_profile=ECR_AWS_PROFILE,
aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE
):
self.monitoring_metrics_gateway.emit_image_build_cache_miss_metric(image_type)
tags = [
Expand Down Expand Up @@ -713,7 +713,7 @@ async def _build_image(
if not build_result_status and not self.docker_repository.image_exists(
repository_name=image_params.repo,
image_tag=image_params.image_tag,
aws_profile=ECR_AWS_PROFILE,
aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE
):
log_error(
f"Image build failed for endpoint {model_endpoint_name}, user {user_id}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ async def create_model_endpoint(
default_callback_url: Optional[str] = None,
default_callback_auth: Optional[CallbackAuth],
public_inference: Optional[bool] = False,
queue_message_timeout_duration: Optional[int] = None,
) -> ModelEndpointRecord:
existing_endpoints = (
await self.model_endpoint_record_repository.list_model_endpoint_records(
Expand Down Expand Up @@ -209,6 +210,7 @@ async def create_model_endpoint(
high_priority=high_priority,
default_callback_url=default_callback_url,
default_callback_auth=default_callback_auth,
queue_message_timeout_duration=queue_message_timeout_duration,
)
await self.model_endpoint_record_repository.update_model_endpoint_record(
model_endpoint_id=model_endpoint_record.id,
Expand Down
Loading