diff --git a/src/google/adk/tools/google_api_tool/google_api_tool.py b/src/google/adk/tools/google_api_tool/google_api_tool.py index d2bac5686d..04d1ebb4b6 100644 --- a/src/google/adk/tools/google_api_tool/google_api_tool.py +++ b/src/google/adk/tools/google_api_tool/google_api_tool.py @@ -39,6 +39,8 @@ def __init__( client_id: Optional[str] = None, client_secret: Optional[str] = None, service_account: Optional[ServiceAccount] = None, + *, + additional_headers: Optional[Dict[str, str]] = None, ): super().__init__( name=rest_api_tool.name, @@ -46,9 +48,11 @@ def __init__( is_long_running=rest_api_tool.is_long_running, ) self._rest_api_tool = rest_api_tool + if additional_headers: + self._rest_api_tool.set_default_headers(additional_headers) if service_account is not None: self.configure_sa_auth(service_account) - else: + elif client_id is not None and client_secret is not None: self.configure_auth(client_id, client_secret) @override @@ -57,7 +61,7 @@ def _get_declaration(self) -> FunctionDeclaration: @override async def run_async( - self, *, args: dict[str, Any], tool_context: Optional[ToolContext] + self, *, args: Dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: return await self._rest_api_tool.run_async( args=args, tool_context=tool_context diff --git a/src/google/adk/tools/google_api_tool/google_api_toolset.py b/src/google/adk/tools/google_api_tool/google_api_toolset.py index 7e5de3e595..714e654229 100644 --- a/src/google/adk/tools/google_api_tool/google_api_toolset.py +++ b/src/google/adk/tools/google_api_tool/google_api_toolset.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Dict from typing import List from typing import Optional from typing import Union @@ -45,6 +46,8 @@ class GoogleApiToolset(BaseToolset): tool_filter: Optional filter to include only specific tools or use a predicate function. service_account: Optional service account for authentication. tool_name_prefix: Optional prefix to add to all tool names in this toolset. + additional_headers: Optional dict of HTTP headers to inject into every request + executed by this toolset. """ def __init__( @@ -56,6 +59,8 @@ def __init__( tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, service_account: Optional[ServiceAccount] = None, tool_name_prefix: Optional[str] = None, + *, + additional_headers: Optional[Dict[str, str]] = None, ): super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix) self.api_name = api_name @@ -63,6 +68,7 @@ def __init__( self._client_id = client_id self._client_secret = client_secret self._service_account = service_account + self._additional_headers = additional_headers self._openapi_toolset = self._load_toolset_with_oidc_auth() @override @@ -72,7 +78,11 @@ async def get_tools( """Get all tools in the toolset.""" return [ GoogleApiTool( - tool, self._client_id, self._client_secret, self._service_account + tool, + self._client_id, + self._client_secret, + self._service_account, + additional_headers=self._additional_headers, ) for tool in await self._openapi_toolset.get_tools(readonly_context) if self._is_tool_selected(tool, readonly_context) diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 2c02d55510..2f16e8ba87 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -134,6 +134,7 @@ def __init__( # Private properties self.credential_exchanger = AutoAuthCredentialExchanger() + self._default_headers: Dict[str, str] = {} if should_parse_operation: self._operation_parser = OperationParser(self.operation) @@ -216,6 +217,10 @@ def configure_auth_credential( auth_credential = AuthCredential.model_validate_json(auth_credential) self.auth_credential = auth_credential + def set_default_headers(self, headers: Dict[str, str]): + """Sets default headers that are merged into every request.""" + self._default_headers = headers + def _prepare_auth_request_params( self, auth_scheme: AuthScheme, @@ -335,6 +340,9 @@ def _prepare_request_params( k: v for k, v in query_params.items() if v is not None } + for key, value in self._default_headers.items(): + header_params.setdefault(key, value) + request_params: Dict[str, Any] = { "method": method, "url": url, diff --git a/tests/unittests/tools/google_api_tool/test_google_api_tool.py b/tests/unittests/tools/google_api_tool/test_google_api_tool.py index 0d9c1f9efb..9e4761fe0a 100644 --- a/tests/unittests/tools/google_api_tool/test_google_api_tool.py +++ b/tests/unittests/tools/google_api_tool/test_google_api_tool.py @@ -56,6 +56,14 @@ def test_init(self, mock_rest_api_tool): assert tool.is_long_running is False assert tool._rest_api_tool == mock_rest_api_tool + def test_init_with_additional_headers(self, mock_rest_api_tool): + """Test GoogleApiTool initialization with additional headers.""" + headers = {"developer-token": "test-token"} + + GoogleApiTool(mock_rest_api_tool, additional_headers=headers) + + mock_rest_api_tool.set_default_headers.assert_called_once_with(headers) + def test_get_declaration(self, mock_rest_api_tool): """Test _get_declaration method.""" tool = GoogleApiTool(mock_rest_api_tool) diff --git a/tests/unittests/tools/google_api_tool/test_google_api_toolset.py b/tests/unittests/tools/google_api_tool/test_google_api_toolset.py index 9dc89ff69c..5da0cb4bcb 100644 --- a/tests/unittests/tools/google_api_tool/test_google_api_toolset.py +++ b/tests/unittests/tools/google_api_tool/test_google_api_toolset.py @@ -126,12 +126,14 @@ def test_init( client_id = "test_client_id" client_secret = "test_client_secret" + additional_headers = {"developer-token": "abc123"} tool_set = GoogleApiToolset( api_name=TEST_API_NAME, api_version=TEST_API_VERSION, client_id=client_id, client_secret=client_secret, + additional_headers=additional_headers, ) assert tool_set.api_name == TEST_API_NAME @@ -141,6 +143,7 @@ def test_init( assert tool_set._service_account is None assert tool_set.tool_filter is None assert tool_set._openapi_toolset == mock_openapi_toolset_instance + assert tool_set._additional_headers == additional_headers mock_converter_class.assert_called_once_with( TEST_API_NAME, TEST_API_VERSION @@ -191,6 +194,7 @@ async def test_get_tools( client_id = "cid" client_secret = "csecret" sa_mock = mock.MagicMock(spec=ServiceAccount) + additional_headers = {"developer-token": "token"} tool_set = GoogleApiToolset( api_name=TEST_API_NAME, @@ -198,6 +202,7 @@ async def test_get_tools( client_id=client_id, client_secret=client_secret, service_account=sa_mock, + additional_headers=additional_headers, ) tools = await tool_set.get_tools(mock_readonly_context) @@ -209,7 +214,11 @@ async def test_get_tools( for i, rest_tool in enumerate(mock_rest_api_tools): mock_google_api_tool_class.assert_any_call( - rest_tool, client_id, client_secret, sa_mock + rest_tool, + client_id, + client_secret, + sa_mock, + additional_headers=additional_headers, ) assert tools[i] is mock_google_api_tool_instances[i] diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index c4cbea7b9b..02b496bc53 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -686,6 +686,65 @@ def test_prepare_request_params_unknown_parameter( # Make sure unknown parameters are ignored and do not raise errors. assert "unknown_param" not in request_params["params"] + def test_prepare_request_params_merges_default_headers( + self, + sample_endpoint, + sample_auth_credential, + sample_auth_scheme, + sample_operation, + ): + tool = RestApiTool( + name="test_tool", + description="Test Tool", + endpoint=sample_endpoint, + operation=sample_operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + tool.set_default_headers({"developer-token": "token"}) + + request_params = tool._prepare_request_params([], {}) + + assert request_params["headers"]["developer-token"] == "token" + + def test_prepare_request_params_preserves_existing_headers( + self, + sample_endpoint, + sample_auth_credential, + sample_auth_scheme, + sample_operation, + sample_api_parameters, + ): + tool = RestApiTool( + name="test_tool", + description="Test Tool", + endpoint=sample_endpoint, + operation=sample_operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + tool.set_default_headers({ + "Content-Type": "text/plain", + "developer-token": "token", + "User-Agent": "custom-default", + }) + + header_param = ApiParameter( + original_name="User-Agent", + py_name="user_agent", + param_location="header", + param_schema=OpenAPISchema(type="string"), + ) + + params = sample_api_parameters + [header_param] + kwargs = {"test_body_param": "value", "user_agent": "api-client"} + + request_params = tool._prepare_request_params(params, kwargs) + + assert request_params["headers"]["Content-Type"] == "application/json" + assert request_params["headers"]["developer-token"] == "token" + assert request_params["headers"]["User-Agent"] == "api-client" + def test_prepare_request_params_base_url_handling( self, sample_auth_credential, sample_auth_scheme, sample_operation ):