Skip to content

Commit 7da087e

Browse files
authored
Add AWS creds config to object_store (#98)
Signed-off-by: Rajvaibhav Rahane <rrahane@amazon.com>
1 parent c39ed73 commit 7da087e

File tree

8 files changed

+224
-50
lines changed

8 files changed

+224
-50
lines changed

e2e/api/vector_dataset_generator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from botocore.exceptions import ClientError
1313
from core.common.models import IndexBuildParameters
1414
from core.object_store.object_store_factory import ObjectStoreFactory
15+
from core.object_store.s3.s3_object_store_config import S3ClientConfig
1516
from core.object_store.types import ObjectStoreType
1617
import logging
1718
from tqdm import tqdm
@@ -38,11 +39,11 @@ def initialize_object_store(self):
3839
doc_count=5, # Will be set per dataset
3940
)
4041
object_store_config = {
41-
"retries": s3_config["retries"],
42-
"region": s3_config["region"],
43-
"S3_ENDPOINT_URL": os.environ.get(
44-
"S3_ENDPOINT_URL", "http://localhost:4566"
45-
),
42+
"s3_client_config": S3ClientConfig(
43+
region_name=s3_config["region"],
44+
max_retries=s3_config["retries"],
45+
endpoint_url=os.environ.get("S3_ENDPOINT_URL", "http://localhost:4566"),
46+
)
4647
}
4748
return ObjectStoreFactory.create_object_store(
4849
index_build_params, object_store_config

remote_vector_index_builder/app/services/index_builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
from typing import Optional, Tuple
1111
from app.models.workflow import BuildWorkflow
12+
from core.object_store.s3.s3_object_store_config import S3ClientConfig
1213
from core.tasks import run_tasks
1314

1415
logger = logging.getLogger(__name__)
@@ -36,7 +37,13 @@ def build_index(
3637
"""
3738
s3_endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
3839
result = run_tasks(
39-
workflow.index_build_parameters, {"S3_ENDPOINT_URL": s3_endpoint_url}
40+
workflow.index_build_parameters,
41+
{
42+
"s3_client_config": S3ClientConfig(
43+
region_name=os.environ.get("AWS_DEFAULT_REGION", None),
44+
endpoint_url=s3_endpoint_url,
45+
),
46+
},
4047
)
4148
if not result.file_name:
4249
return False, None, result.error

remote_vector_index_builder/core/object_store/s3/s3_object_store.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import threading
1313
from functools import cache
1414
from io import BytesIO
15-
from typing import Any, Dict, Optional
15+
from typing import Any, Dict
1616

1717
import boto3
1818
from boto3.s3.transfer import TransferConfig
@@ -21,6 +21,7 @@
2121
from core.common.exceptions import BlobError
2222
from core.common.models import IndexBuildParameters
2323
from core.object_store.object_store import ObjectStore
24+
from core.object_store.s3.s3_object_store_config import S3ClientConfig
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -46,28 +47,24 @@ def get_cpus(factor: float) -> int:
4647

4748

4849
@cache
49-
def get_boto3_client(
50-
region: str, retries: int, endpoint_url: Optional[str] = None
51-
) -> boto3.client:
50+
def get_boto3_client(s3_client_config: S3ClientConfig) -> boto3.client:
5251
"""Create or retrieve a cached boto3 S3 client.
5352
5453
Args:
55-
region (str): AWS region name for the S3 client
56-
retries (int): Maximum number of retry attempts for failed requests
57-
endpoint_url (str): s3 endpoint URL. Defaults to None, in which case boto3
58-
automatically constructs the appropriate URL to use when communicating
59-
with a service. During integration testing, this can be set to the endpoint URL
60-
for LocalStack S3 service.
54+
s3_client_config (S3ClientConfig): Configuration class for creating S3 Boto3 client
6155
6256
Returns:
6357
boto3.client: Configured S3 client instance
6458
"""
65-
config = Config(retries={"max_attempts": retries})
59+
config = Config(retries={"max_attempts": s3_client_config.max_retries})
6660
return boto3.client(
6761
"s3",
6862
config=config,
69-
region_name=region,
70-
endpoint_url=endpoint_url,
63+
region_name=s3_client_config.region_name,
64+
endpoint_url=s3_client_config.endpoint_url,
65+
aws_access_key_id=s3_client_config.aws_access_key_id,
66+
aws_secret_access_key=s3_client_config.aws_secret_access_key,
67+
aws_session_token=s3_client_config.aws_session_token,
7168
)
7269

7370

@@ -102,11 +99,27 @@ def __init__(
10299
Args:
103100
index_build_params (IndexBuildParameters): Contains bucket name and other
104101
index building parameters
102+
105103
object_store_config (Dict[str, Any]): Configuration dictionary containing:
106-
- retries (int): Maximum number of retry attempts (default: 3)
107-
- region (str): AWS region name (default: 'us-west-2')
104+
Contains:
108105
- transfer_config (Dict[str, Any]): s3 TransferConfig parameters
109106
- debug: Turns on debug mode (default: False)
107+
- s3_client_config (S3ClientConfig) (Required):
108+
Required:
109+
- region_name (str) (required): AWS Region name
110+
Optional:
111+
- endpoint_url (Optional[str]): Custom S3 endpoint URL
112+
- max_retries (int) (default: 3): Maximum number of retry attempts for failed requests
113+
114+
AWS Credentials (all optional):
115+
- aws_access_key_id (Optional[str]): AWS Access Key ID
116+
- aws_secret_access_key (Optional[str]): AWS Secret Access Key
117+
- aws_session_token (Optional[str]): Temporary session token for STS credentials
118+
119+
Note:
120+
AWS credentials are optional as boto3 will attempt to find credentials:
121+
For more details see boto3 client documentation:
122+
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
110123
"""
111124

112125
self.DEFAULT_DOWNLOAD_TRANSFER_CONFIG = {
@@ -130,14 +143,12 @@ def __init__(
130143
"ChecksumAlgorithm": "CRC32",
131144
}
132145
self.bucket = index_build_params.container_name
133-
self.max_retries = object_store_config.get("retries", 3)
134-
self.region = object_store_config.get("region", "us-west-2")
135146

136-
self.s3_client = get_boto3_client(
137-
region=self.region,
138-
retries=self.max_retries,
139-
endpoint_url=object_store_config.get("S3_ENDPOINT_URL"),
140-
)
147+
s3_client_config: S3ClientConfig = object_store_config.get("s3_client_config")
148+
self.max_retries = s3_client_config.max_retries
149+
self.region = s3_client_config.region_name
150+
151+
self.s3_client = get_boto3_client(s3_client_config)
141152

142153
download_transfer_config = object_store_config.get(
143154
"download_transfer_config", {}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel
4+
5+
6+
class S3ClientConfig(BaseModel):
7+
"""
8+
Configuration class for creating S3 Boto3 client
9+
10+
Initialize the S3ObjectStore boto3 client with the following parameters:
11+
12+
Attributes:
13+
region_name (str) (required): AWS Region name
14+
endpoint_url (Optional[str]): Custom S3 endpoint URL
15+
max_retries (int) (default: 3): Maximum number of retry attempts for failed requests
16+
17+
AWS Credentials (all optional):
18+
aws_access_key_id (Optional[str]): AWS Access Key ID
19+
aws_secret_access_key (Optional[str]): AWS Secret Access Key
20+
aws_session_token (Optional[str]): Temporary session token for STS credentials
21+
22+
Note:
23+
AWS credentials are optional as boto3 will attempt to find credentials:
24+
For more details see boto3 client documentation:
25+
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
26+
"""
27+
28+
region_name: str
29+
endpoint_url: Optional[str] = None
30+
max_retries: int = 3
31+
32+
# AWS Credentials parameters
33+
# Ref: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
34+
aws_access_key_id: Optional[str] = None
35+
aws_secret_access_key: Optional[str] = None
36+
aws_session_token: Optional[str] = None
37+
38+
def __hash__(self):
39+
"""
40+
Generate a hash value for this configuration.
41+
42+
Required for @functools.cache to use this class as a dictionary key.
43+
All attributes that affect the client configuration must be included
44+
in the hash tuple to ensure proper cache behavior.
45+
46+
Returns:
47+
int: Hash value based on all configuration attributes
48+
"""
49+
return hash(
50+
(
51+
self.region_name,
52+
self.max_retries,
53+
self.endpoint_url,
54+
self.aws_access_key_id,
55+
self.aws_secret_access_key,
56+
self.aws_session_token,
57+
)
58+
)
59+
60+
def __eq__(self, other):
61+
"""
62+
Compare this configuration with another for equality.
63+
64+
Required for @functools.cache to properly identify cache hits.
65+
Two configurations are equal if all their attributes are equal.
66+
67+
Args:
68+
other: Another object to compare with this configuration
69+
70+
Returns:
71+
bool: True if other is an S3ClientConfig with identical attributes,
72+
False otherwise
73+
74+
Note:
75+
Returns NotImplemented for non-S3ClientConfig objects to allow
76+
Python to try other comparison methods.
77+
"""
78+
if not isinstance(other, S3ClientConfig):
79+
return NotImplemented
80+
return (
81+
self.region_name == other.region_name
82+
and self.max_retries == other.max_retries
83+
and self.endpoint_url == other.endpoint_url
84+
and self.aws_access_key_id == other.aws_access_key_id
85+
and self.aws_secret_access_key == other.aws_secret_access_key
86+
and self.aws_session_token == other.aws_session_token
87+
)

test_remote_vector_index_builder/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# compatible open source license.
77

88
import pytest
9+
import os
910
from core.common.models.index_build_parameters import (
1011
AlgorithmParameters,
1112
IndexBuildParameters,
@@ -31,3 +32,11 @@ def index_build_parameters():
3132
),
3233
repository_type="s3",
3334
)
35+
36+
37+
@pytest.fixture(autouse=True)
38+
def aws_credentials():
39+
"""Mocked AWS Credentials for tests."""
40+
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
41+
yield
42+
os.environ.pop("AWS_DEFAULT_REGION", None)

test_remote_vector_index_builder/test_core/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from unittest.mock import Mock
1313

1414
from core.common.models import VectorsDataset
15+
from core.object_store.s3.s3_object_store_config import S3ClientConfig
1516

1617

1718
class DeletionTracker:
@@ -197,7 +198,7 @@ def _write_index(self, index, filepath):
197198
def object_store_config():
198199
"""Create a sample object store configuration for testing"""
199200
return {
200-
"region": "us-west-2",
201+
"s3_client_config": S3ClientConfig(region_name="us-west-2", max_retries="4")
201202
}
202203

203204

0 commit comments

Comments
 (0)