Skip to content
6 changes: 4 additions & 2 deletions src/google/adk/tools/google_api_tool/google_api_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,18 @@ def __init__(
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
service_account: Optional[ServiceAccount] = None,
additional_headers: Optional[Dict[str, str]] = None,
):
super().__init__(
name=rest_api_tool.name,
description=rest_api_tool.description,
is_long_running=rest_api_tool.is_long_running,
)
self._rest_api_tool = rest_api_tool
self._rest_api_tool.set_default_headers(additional_headers or {})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be if additional_headers: self._rest_api_tool.set_default_headers(additional_headers) works better? To avoid accidentally clean up default_headers when user not set additional_headers in GoogleApiTool.

if service_account is not None:
self.configure_sa_auth(service_account)
else:
elif client_id is not None and client_secret is not None:
self.configure_auth(client_id, client_secret)

@override
Expand All @@ -57,7 +59,7 @@ def _get_declaration(self) -> FunctionDeclaration:

@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
self, *, args: Dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
return await self._rest_api_tool.run_async(
args=args, tool_context=tool_context
Expand Down
11 changes: 10 additions & 1 deletion src/google/adk/tools/google_api_tool/google_api_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from typing import Dict
from typing import List
from typing import Optional
from typing import Union
Expand Down Expand Up @@ -45,6 +46,8 @@ class GoogleApiToolset(BaseToolset):
tool_filter: Optional filter to include only specific tools or use a predicate function.
service_account: Optional service account for authentication.
tool_name_prefix: Optional prefix to add to all tool names in this toolset.
additional_headers: Optional dict of HTTP headers to inject into every request
executed by this toolset.
"""

def __init__(
Expand All @@ -56,13 +59,15 @@ def __init__(
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
service_account: Optional[ServiceAccount] = None,
tool_name_prefix: Optional[str] = None,
additional_headers: Optional[Dict[str, str]] = None,
):
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
self.api_name = api_name
self.api_version = api_version
self._client_id = client_id
self._client_secret = client_secret
self._service_account = service_account
self._additional_headers = dict(additional_headers or {})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just do self._additional_headers = additional_headers here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, thanks, I had to change the the type mid way, and I forgot that I already changed it to dict.

self._openapi_toolset = self._load_toolset_with_oidc_auth()

@override
Expand All @@ -72,7 +77,11 @@ async def get_tools(
"""Get all tools in the toolset."""
return [
GoogleApiTool(
tool, self._client_id, self._client_secret, self._service_account
tool,
self._client_id,
self._client_secret,
self._service_account,
self._additional_headers,
)
for tool in await self._openapi_toolset.get_tools(readonly_context)
if self._is_tool_selected(tool, readonly_context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(
else operation
)
self.auth_credential, self.auth_scheme = None, None
self._default_headers: Dict[str, str] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are defining it as a private data member, probably better to move it after line 136 (the Private properties section)?


self.configure_auth_credential(auth_credential)
self.configure_auth_scheme(auth_scheme)
Expand Down Expand Up @@ -216,6 +217,10 @@ def configure_auth_credential(
auth_credential = AuthCredential.model_validate_json(auth_credential)
self.auth_credential = auth_credential

def set_default_headers(self, headers: Dict[str, str]):
"""Sets default headers that are merged into every request."""
self._default_headers = dict(headers)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this dict cast here?


def _prepare_auth_request_params(
self,
auth_scheme: AuthScheme,
Expand Down Expand Up @@ -335,6 +340,9 @@ def _prepare_request_params(
k: v for k, v in query_params.items() if v is not None
}

for key, value in self._default_headers.items():
header_params.setdefault(key, value)

request_params: Dict[str, Any] = {
"method": method,
"url": url,
Expand Down
8 changes: 8 additions & 0 deletions tests/unittests/tools/google_api_tool/test_google_api_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def test_init(self, mock_rest_api_tool):
assert tool.is_long_running is False
assert tool._rest_api_tool == mock_rest_api_tool

def test_init_with_additional_headers(self, mock_rest_api_tool):
"""Test GoogleApiTool initialization with additional headers."""
headers = {"developer-token": "test-token"}

GoogleApiTool(mock_rest_api_tool, additional_headers=headers)

mock_rest_api_tool.set_default_headers.assert_called_once_with(headers)

def test_get_declaration(self, mock_rest_api_tool):
"""Test _get_declaration method."""
tool = GoogleApiTool(mock_rest_api_tool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,14 @@ def test_init(

client_id = "test_client_id"
client_secret = "test_client_secret"
additional_headers = {"developer-token": "abc123"}

tool_set = GoogleApiToolset(
api_name=TEST_API_NAME,
api_version=TEST_API_VERSION,
client_id=client_id,
client_secret=client_secret,
additional_headers=additional_headers,
)

assert tool_set.api_name == TEST_API_NAME
Expand All @@ -141,6 +143,7 @@ def test_init(
assert tool_set._service_account is None
assert tool_set.tool_filter is None
assert tool_set._openapi_toolset == mock_openapi_toolset_instance
assert tool_set._additional_headers == additional_headers

mock_converter_class.assert_called_once_with(
TEST_API_NAME, TEST_API_VERSION
Expand Down Expand Up @@ -191,13 +194,15 @@ async def test_get_tools(
client_id = "cid"
client_secret = "csecret"
sa_mock = mock.MagicMock(spec=ServiceAccount)
additional_headers = {"developer-token": "token"}

tool_set = GoogleApiToolset(
api_name=TEST_API_NAME,
api_version=TEST_API_VERSION,
client_id=client_id,
client_secret=client_secret,
service_account=sa_mock,
additional_headers=additional_headers,
)

tools = await tool_set.get_tools(mock_readonly_context)
Expand All @@ -209,7 +214,7 @@ async def test_get_tools(

for i, rest_tool in enumerate(mock_rest_api_tools):
mock_google_api_tool_class.assert_any_call(
rest_tool, client_id, client_secret, sa_mock
rest_tool, client_id, client_secret, sa_mock, additional_headers
)
assert tools[i] is mock_google_api_tool_instances[i]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,65 @@ def test_prepare_request_params_unknown_parameter(
# Make sure unknown parameters are ignored and do not raise errors.
assert "unknown_param" not in request_params["params"]

def test_prepare_request_params_merges_default_headers(
self,
sample_endpoint,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpoint,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
tool.set_default_headers({"developer-token": "token"})

request_params = tool._prepare_request_params([], {})

assert request_params["headers"]["developer-token"] == "token"

def test_prepare_request_params_preserves_existing_headers(
self,
sample_endpoint,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
sample_api_parameters,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpoint,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
tool.set_default_headers({
"Content-Type": "text/plain",
"developer-token": "token",
"User-Agent": "custom-default",
})

header_param = ApiParameter(
original_name="User-Agent",
py_name="user_agent",
param_location="header",
param_schema=OpenAPISchema(type="string"),
)

params = sample_api_parameters + [header_param]
kwargs = {"test_body_param": "value", "user_agent": "api-client"}

request_params = tool._prepare_request_params(params, kwargs)

assert request_params["headers"]["Content-Type"] == "application/json"
assert request_params["headers"]["developer-token"] == "token"
assert request_params["headers"]["User-Agent"] == "api-client"

def test_prepare_request_params_base_url_handling(
self, sample_auth_credential, sample_auth_scheme, sample_operation
):
Expand Down