Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions engine/baml-lib/llm-client/src/clients/aws_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub struct UnresolvedAwsBedrock<Meta> {
secret_access_key: Option<StringOr>,
session_token: Option<StringOr>,
profile: Option<StringOr>,
endpoint_url: Option<StringOr>,
role_selection: UnresolvedRolesSelection,
allowed_role_metadata: UnresolvedAllowedRoleMetadata,
supported_request_modes: SupportedRequestModes,
Expand Down Expand Up @@ -78,6 +79,7 @@ pub struct ResolvedAwsBedrock {
pub secret_access_key: Option<ApiKeyWithProvenance>,
pub session_token: Option<String>,
pub profile: Option<String>,
pub endpoint_url: Option<String>,
pub inference_config: Option<InferenceConfiguration>,
role_selection: RolesSelection,
pub allowed_role_metadata: AllowedRoleMetadata,
Expand All @@ -95,6 +97,7 @@ impl std::fmt::Debug for ResolvedAwsBedrock {
.field("secret_access_key", &"<no-repr-available>")
.field("session_token", &self.session_token)
.field("profile", &self.profile)
.field("endpoint_url", &self.endpoint_url)
.field("inference_config", &"<no-repr-available>")
.field("role_selection", &self.role_selection)
.field("allowed_role_metadata", &self.allowed_role_metadata)
Expand All @@ -121,6 +124,12 @@ impl ResolvedAwsBedrock {
serde_json::Value::String(region.clone()),
);
}
if let Some(endpoint_url) = &self.endpoint_url {
options.insert(
"endpoint_url".to_string(),
serde_json::Value::String(endpoint_url.clone()),
);
}
options
}

Expand Down Expand Up @@ -162,6 +171,7 @@ impl<Meta: Clone> UnresolvedAwsBedrock<Meta> {
secret_access_key: self.secret_access_key.clone(),
session_token: self.session_token.clone(),
profile: self.profile.clone(),
endpoint_url: self.endpoint_url.clone(),
role_selection: self.role_selection.clone(),
allowed_role_metadata: self.allowed_role_metadata.clone(),
supported_request_modes: self.supported_request_modes.clone(),
Expand Down Expand Up @@ -215,6 +225,11 @@ impl<Meta: Clone> UnresolvedAwsBedrock<Meta> {
}
}

match self.endpoint_url.as_ref() {
Some(endpoint_url) => env_vars.extend(endpoint_url.required_env_vars()),
None => {}
}

env_vars.extend(self.role_selection.required_env_vars());
env_vars.extend(self.allowed_role_metadata.required_env_vars());
env_vars.extend(self.supported_request_modes.required_env_vars());
Expand Down Expand Up @@ -357,6 +372,18 @@ impl<Meta: Clone> UnresolvedAwsBedrock<Meta> {
},
};

let endpoint_url = match self.endpoint_url.as_ref() {
Some(endpoint_url) => {
let url = endpoint_url.resolve(ctx)?;
if url.is_empty() {
None
} else {
Some(url)
}
Comment on lines +377 to +382

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correctness: endpoint_url is resolved and passed as Some(url) even if the resolved string is empty, which may cause downstream consumers to receive an invalid endpoint.

🤖 AI Agent Prompt for Cursor/Windsurf

📋 Copy this prompt to your AI coding assistant (Cursor, Windsurf, etc.) to get help fixing this issue

In engine/baml-lib/llm-client/src/clients/aws_bedrock.rs, lines 377-382, the code currently checks if the resolved endpoint_url is empty using `is_empty()`, but this will not catch strings that are only whitespace. Update the check to use `url.trim().is_empty()` to ensure that whitespace-only endpoint URLs are also treated as None.
📝 Committable Code Suggestion

‼️ Ensure you review the code suggestion before committing it to the branch. Make sure it replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
let url = endpoint_url.resolve(ctx)?;
if url.is_empty() {
None
} else {
Some(url)
}
let url = endpoint_url.resolve(ctx)?;
if url.trim().is_empty() {
None
} else {
Some(url)
}

}
None => None,
};

#[cfg(target_arch = "wasm32")]
{
if region.is_none() {
Expand All @@ -377,6 +404,7 @@ impl<Meta: Clone> UnresolvedAwsBedrock<Meta> {
secret_access_key,
session_token,
profile,
endpoint_url,
role_selection,
allowed_role_metadata: self.allowed_role_metadata.resolve(ctx)?,
supported_request_modes: self.supported_request_modes.clone(),
Expand Down Expand Up @@ -430,6 +458,9 @@ impl<Meta: Clone> UnresolvedAwsBedrock<Meta> {
let profile = properties
.ensure_string("profile", false)
.map(|(_, v, _)| v.clone());
let endpoint_url = properties
.ensure_string("endpoint_url", false)
.map(|(_, v, _)| v.clone());

let role_selection = properties.ensure_roles_selection();
let allowed_metadata = properties.ensure_allowed_metadata();
Expand Down Expand Up @@ -511,6 +542,7 @@ impl<Meta: Clone> UnresolvedAwsBedrock<Meta> {
secret_access_key,
session_token,
profile,
endpoint_url,
role_selection,
allowed_role_metadata: allowed_metadata,
supported_request_modes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ impl AwsClient {
let config = loader.load().await;
let http_client = custom_http_client::client()?;

let bedrock_config = aws_sdk_bedrockruntime::config::Builder::from(&config)
let mut bedrock_config = aws_sdk_bedrockruntime::config::Builder::from(&config)
// To support HTTPS_PROXY https://github.com/awslabs/aws-sdk-rust/issues/169
.http_client(http_client)
// Adding a custom http client (above) breaks the stalled stream protection for some reason. If a bedrock request takes longer than 5s (the default grace period, it makes it error out), so we disable it.
Expand All @@ -601,9 +601,14 @@ impl AwsClient {
call_stack,
http_request_id.clone(),
&self.properties,
))
.build();
Ok(BedrockRuntimeClient::from_conf(bedrock_config))
));

// Set endpoint_url if specified
if let Some(endpoint_url) = self.properties.endpoint_url.as_ref() {
bedrock_config = bedrock_config.endpoint_url(endpoint_url);
}
Comment on lines +607 to +609

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correctness: bedrock_config.endpoint_url(endpoint_url) does not validate that endpoint_url is a valid URL, so passing an invalid value may cause runtime errors or panics when building the client.

🤖 AI Agent Prompt for Cursor/Windsurf

📋 Copy this prompt to your AI coding assistant (Cursor, Windsurf, etc.) to get help fixing this issue

In engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs, lines 607-609, the code sets the endpoint_url for the Bedrock client without validating that the provided string is a valid URL. This can cause runtime panics or errors if an invalid URL is passed. Please update this block to parse and validate the endpoint_url using http::Uri::parse, and return an error if the URL is invalid, before setting it on the config builder.
📝 Committable Code Suggestion

‼️ Ensure you review the code suggestion before committing it to the branch. Make sure it replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
if let Some(endpoint_url) = self.properties.endpoint_url.as_ref() {
bedrock_config = bedrock_config.endpoint_url(endpoint_url);
}
if let Some(endpoint_url) = self.properties.endpoint_url.as_ref() {
if let Err(e) = endpoint_url.parse::<http::Uri>() {
return Err(anyhow::anyhow!("Invalid endpoint_url: {}", e));
}
bedrock_config = bedrock_config.endpoint_url(endpoint_url);
}


Ok(BedrockRuntimeClient::from_conf(bedrock_config.build()))
}

async fn chat_anyhow(&self, response: &ConverseOutput) -> Result<String> {
Expand Down
Loading