Skip to content

Commit 6f9b87e

Browse files
authored
[Tiny-Agent] Fix headers handling + secrets management (#3166)
* [Tiny-Agent] Fix headers handling + secrets management * no whitespace in id * switch to VSCode config file format
1 parent 85d752c commit 6f9b87e

File tree

3 files changed

+19
-51
lines changed

3 files changed

+19
-51
lines changed

src/huggingface_hub/inference/_mcp/agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from .._providers import PROVIDER_OR_POLICY_T
99
from .constants import DEFAULT_SYSTEM_PROMPT, EXIT_LOOP_TOOLS, MAX_NUM_TURNS
10+
from .types import ServerConfig
1011

1112

1213
class Agent(MCPClient):
@@ -40,7 +41,7 @@ def __init__(
4041
self,
4142
*,
4243
model: Optional[str] = None,
43-
servers: Iterable[Dict],
44+
servers: Iterable[ServerConfig],
4445
provider: Optional[PROVIDER_OR_POLICY_T] = None,
4546
base_url: Optional[str] = None,
4647
api_key: Optional[str] = None,
@@ -54,7 +55,7 @@ def __init__(
5455

5556
async def load_tools(self) -> None:
5657
for cfg in self._servers_cfg:
57-
await self.add_mcp_server(cfg["type"], **cfg["config"])
58+
await self.add_mcp_server(**cfg)
5859

5960
async def run(
6061
self,

src/huggingface_hub/inference/_mcp/cli.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,7 @@ def _sigint_handler() -> None:
8585
input_vars = set()
8686
for server in servers:
8787
# Check stdio's "env" and http/sse's "headers" mappings
88-
env_or_headers = (
89-
server["config"].get("env", {})
90-
if server["type"] == "stdio"
91-
else server["config"].get("options", {}).get("requestInit", {}).get("headers", {})
92-
)
88+
env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
9389
for key, value in env_or_headers.items():
9490
if env_special_value in value:
9591
input_vars.add(key)
@@ -99,8 +95,9 @@ def _sigint_handler() -> None:
9995
continue
10096

10197
# Prompt user for input
98+
env_variable_key = input_id.replace("-", "_").upper()
10299
print(
103-
f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(sorted(input_vars))}).",
100+
f"[blue] • {input_id}[/blue]: {description}. (default: load from {env_variable_key}).",
104101
end=" ",
105102
)
106103
user_input = (await _async_prompt(exit_event=exit_event)).strip()
@@ -109,23 +106,19 @@ def _sigint_handler() -> None:
109106

110107
# Inject user input (or env variable) into stdio's env or http/sse's headers
111108
for server in servers:
112-
env_or_headers = (
113-
server["config"].get("env", {})
114-
if server["type"] == "stdio"
115-
else server["config"].get("options", {}).get("requestInit", {}).get("headers", {})
116-
)
109+
env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
117110
for key, value in env_or_headers.items():
118111
if env_special_value in value:
119112
if user_input:
120113
env_or_headers[key] = env_or_headers[key].replace(env_special_value, user_input)
121114
else:
122-
value_from_env = os.getenv(key, "")
115+
value_from_env = os.getenv(env_variable_key, "")
123116
env_or_headers[key] = env_or_headers[key].replace(env_special_value, value_from_env)
124117
if value_from_env:
125-
print(f"[green]Value successfully loaded from '{key}'[/green]")
118+
print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]")
126119
else:
127120
print(
128-
f"[yellow]No value found for '{key}' in environment variables. Continuing.[/yellow]"
121+
f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]"
129122
)
130123

131124
print()
Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,39 @@
11
from typing import Dict, List, Literal, TypedDict, Union
22

33

4-
# Input config
54
class InputConfig(TypedDict, total=False):
65
id: str
76
description: str
87
type: str
98
password: bool
109

1110

12-
# stdio server config
13-
class StdioServerConfig(TypedDict, total=False):
11+
class StdioServerConfig(TypedDict):
12+
type: Literal["stdio"]
1413
command: str
1514
args: List[str]
1615
env: Dict[str, str]
1716
cwd: str
1817

1918

20-
class StdioServer(TypedDict):
21-
type: Literal["stdio"]
22-
config: StdioServerConfig
23-
24-
25-
# http server config
26-
class HTTPRequestInit(TypedDict, total=False):
27-
headers: Dict[str, str]
28-
29-
30-
class HTTPServerOptions(TypedDict, total=False):
31-
requestInit: HTTPRequestInit
32-
sessionId: str
33-
34-
35-
class HTTPServerConfig(TypedDict, total=False):
36-
url: str
37-
options: HTTPServerOptions
38-
39-
40-
class HTTPServer(TypedDict):
19+
class HTTPServerConfig(TypedDict):
4120
type: Literal["http"]
42-
config: HTTPServerConfig
43-
44-
45-
# sse server config
46-
class SSEServerOptions(TypedDict, total=False):
47-
requestInit: HTTPRequestInit
21+
url: str
22+
headers: Dict[str, str]
4823

4924

5025
class SSEServerConfig(TypedDict):
26+
type: Literal["sse"]
5127
url: str
52-
options: SSEServerOptions
28+
headers: Dict[str, str]
5329

5430

55-
class SSEServer(TypedDict):
56-
type: Literal["sse"]
57-
config: SSEServerConfig
31+
ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
5832

5933

6034
# AgentConfig root object
6135
class AgentConfig(TypedDict):
6236
model: str
6337
provider: str
6438
inputs: List[InputConfig]
65-
servers: List[Union[StdioServer, HTTPServer, SSEServer]]
39+
servers: List[ServerConfig]

0 commit comments

Comments
 (0)