-
Notifications
You must be signed in to change notification settings - Fork 2.1k
feat(tools): support additional headers for google api toolset #non-breaking #3194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
78e4d81
3d6d613
71999b0
a09dd4d
d8ded07
37179a6
ace3e5a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,16 +39,18 @@ def __init__( | |
client_id: Optional[str] = None, | ||
client_secret: Optional[str] = None, | ||
service_account: Optional[ServiceAccount] = None, | ||
additional_headers: Optional[Dict[str, str]] = None, | ||
): | ||
super().__init__( | ||
name=rest_api_tool.name, | ||
description=rest_api_tool.description, | ||
is_long_running=rest_api_tool.is_long_running, | ||
) | ||
self._rest_api_tool = rest_api_tool | ||
self._rest_api_tool.set_default_headers(additional_headers or {}) | ||
|
||
if service_account is not None: | ||
self.configure_sa_auth(service_account) | ||
else: | ||
elif client_id is not None and client_secret is not None: | ||
self.configure_auth(client_id, client_secret) | ||
|
||
@override | ||
|
@@ -57,7 +59,7 @@ def _get_declaration(self) -> FunctionDeclaration: | |
|
||
@override | ||
async def run_async( | ||
self, *, args: dict[str, Any], tool_context: Optional[ToolContext] | ||
self, *, args: Dict[str, Any], tool_context: Optional[ToolContext] | ||
) -> Dict[str, Any]: | ||
return await self._rest_api_tool.run_async( | ||
args=args, tool_context=tool_context | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
|
||
from __future__ import annotations | ||
|
||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
from typing import Union | ||
|
@@ -45,6 +46,8 @@ class GoogleApiToolset(BaseToolset): | |
tool_filter: Optional filter to include only specific tools or use a predicate function. | ||
service_account: Optional service account for authentication. | ||
tool_name_prefix: Optional prefix to add to all tool names in this toolset. | ||
additional_headers: Optional dict of HTTP headers to inject into every request | ||
executed by this toolset. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -56,13 +59,15 @@ def __init__( | |
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, | ||
service_account: Optional[ServiceAccount] = None, | ||
tool_name_prefix: Optional[str] = None, | ||
additional_headers: Optional[Dict[str, str]] = None, | ||
): | ||
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix) | ||
self.api_name = api_name | ||
self.api_version = api_version | ||
self._client_id = client_id | ||
self._client_secret = client_secret | ||
self._service_account = service_account | ||
self._additional_headers = dict(additional_headers or {}) | ||
|
||
self._openapi_toolset = self._load_toolset_with_oidc_auth() | ||
|
||
@override | ||
|
@@ -72,7 +77,11 @@ async def get_tools( | |
"""Get all tools in the toolset.""" | ||
return [ | ||
GoogleApiTool( | ||
tool, self._client_id, self._client_secret, self._service_account | ||
tool, | ||
self._client_id, | ||
self._client_secret, | ||
self._service_account, | ||
self._additional_headers, | ||
) | ||
for tool in await self._openapi_toolset.get_tools(readonly_context) | ||
if self._is_tool_selected(tool, readonly_context) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,6 +128,7 @@ def __init__( | |
else operation | ||
) | ||
self.auth_credential, self.auth_scheme = None, None | ||
self._default_headers: Dict[str, str] = {} | ||
Prhmma marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
self.configure_auth_credential(auth_credential) | ||
self.configure_auth_scheme(auth_scheme) | ||
|
@@ -216,6 +217,10 @@ def configure_auth_credential( | |
auth_credential = AuthCredential.model_validate_json(auth_credential) | ||
self.auth_credential = auth_credential | ||
|
||
def set_default_headers(self, headers: Dict[str, str]): | ||
Prhmma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Sets default headers that are merged into every request.""" | ||
self._default_headers = dict(headers) | ||
|
||
|
||
def _prepare_auth_request_params( | ||
self, | ||
auth_scheme: AuthScheme, | ||
|
@@ -335,6 +340,9 @@ def _prepare_request_params( | |
k: v for k, v in query_params.items() if v is not None | ||
} | ||
|
||
for key, value in self._default_headers.items(): | ||
header_params.setdefault(key, value) | ||
|
||
request_params: Dict[str, Any] = { | ||
"method": method, | ||
"url": url, | ||
|
Uh oh!
There was an error while loading. Please reload this page.