Skip to content

Commit 68e2380

Browse files
committed
Swapped ChromeDevToolsClient out for Any in many places to resolve typing issue.
1 parent 39bfc81 commit 68e2380

File tree

10 files changed

+393
-400
lines changed

10 files changed

+393
-400
lines changed

src/client.py

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
#!/usr/bin/env python3
2+
"""Chrome DevTools Protocol Client
3+
4+
This module contains the ChromeDevToolsClient class that manages WebSocket connections
5+
to Chrome's remote debugging interface and handles CDP commands and events.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import asyncio
11+
import json
12+
import logging
13+
import os
14+
from collections.abc import Callable
15+
from typing import Any
16+
17+
import aiohttp
18+
import websockets
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class ChromeDevToolsClient:
24+
"""
25+
Chrome DevTools Protocol client with WebSocket communication capabilities.
26+
27+
This class manages the connection to Chrome's remote debugging interface,
28+
handles event processing, and executes CDP commands. It maintains state
29+
for network requests and console logs, and provides a robust interface
30+
for web application debugging.
31+
32+
The client automatically discovers available Chrome targets and establishes
33+
WebSocket connections for real-time communication with the browser.
34+
35+
Attributes:
36+
port: Chrome remote debugging port (default: 9222)
37+
host: Hostname for Chrome connection (default: localhost)
38+
ws: WebSocket connection to Chrome DevTools
39+
connected: Connection status flag
40+
message_id: Incremental ID for CDP messages
41+
pending_messages: Awaiting responses for sent commands
42+
event_handlers: Registered handlers for CDP events
43+
network_requests: Captured network request data
44+
console_logs: Captured console log entries
45+
"""
46+
47+
def __init__(self, port: int = 9222, host: str = "localhost") -> None:
48+
"""
49+
Initialise the Chrome DevTools Protocol client.
50+
51+
Args:
52+
port: Chrome remote debugging port (overridden by CHROME_DEBUG_PORT env var)
53+
host: Hostname for Chrome connection
54+
"""
55+
# Use environment variable if available for flexible configuration
56+
env_port = os.getenv("CHROME_DEBUG_PORT")
57+
if env_port and env_port.isdigit():
58+
port = int(env_port)
59+
60+
self.port = port
61+
self.host = host
62+
self.ws: websockets.WebSocketServerProtocol | None = None # type: ignore
63+
self.connected = False
64+
self.message_id = 0
65+
self.pending_messages: dict[int, asyncio.Future] = {}
66+
self.event_handlers: dict[str, list[Callable[[dict[str, Any]], None]]] = {}
67+
68+
# Storage for captured browser data
69+
self.network_requests: list[dict[str, Any]] = []
70+
self.console_logs: list[dict[str, Any]] = []
71+
72+
async def connect(self) -> bool:
73+
"""
74+
Establish connection to Chrome DevTools via WebSocket.
75+
76+
Discovers available Chrome targets and connects to the first available
77+
target using WebSocket communication. Starts the message handling loop
78+
for processing incoming CDP events and responses.
79+
80+
Returns:
81+
bool: True if connection successful, False otherwise
82+
83+
Raises:
84+
ConnectionError: If no browser targets are available
85+
"""
86+
try:
87+
targets = await self._get_available_targets()
88+
if not targets:
89+
raise ConnectionError("No browser targets available")
90+
91+
target = targets[0]
92+
ws_url = target["webSocketDebuggerUrl"]
93+
94+
self.ws = await websockets.connect(ws_url)
95+
self.connected = True
96+
97+
asyncio.create_task(self._handle_incoming_messages())
98+
99+
logger.info(f"Connected to Chrome target: {target.get('title', 'Unknown')}")
100+
return True
101+
102+
except Exception as e:
103+
logger.error(f"Failed to connect to Chrome: {e}")
104+
self.connected = False
105+
return False
106+
107+
async def disconnect(self) -> None:
108+
"""Gracefully disconnect from Chrome DevTools."""
109+
if self.ws:
110+
await self.ws.close()
111+
self.connected = False
112+
self.ws = None
113+
logger.info("Disconnected from Chrome")
114+
115+
async def _get_available_targets(self) -> list[dict[str, Any]]:
116+
"""Retrieve list of available Chrome targets."""
117+
try:
118+
async with aiohttp.ClientSession() as session:
119+
async with session.get(f"http://{self.host}:{self.port}/json") as response:
120+
if response.status == 200:
121+
targets = await response.json()
122+
return [t for t in targets if t.get("type") == "page"]
123+
return []
124+
except Exception as e:
125+
logger.error(f"Failed to get targets: {e}")
126+
return []
127+
128+
async def send_command(
129+
self, method: str, params: dict[str, Any] | None = None
130+
) -> dict[str, Any]:
131+
"""Send a command to Chrome DevTools and wait for response."""
132+
if not self.connected or not self.ws:
133+
raise ConnectionError("Not connected to Chrome")
134+
135+
self.message_id += 1
136+
message = {"id": self.message_id, "method": method, "params": params or {}}
137+
138+
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
139+
self.pending_messages[self.message_id] = future
140+
141+
try:
142+
await self.ws.send(json.dumps(message))
143+
result = await asyncio.wait_for(future, timeout=10.0)
144+
return result # type: ignore
145+
except asyncio.TimeoutError:
146+
if self.message_id in self.pending_messages:
147+
del self.pending_messages[self.message_id]
148+
raise TimeoutError(f"Command {method} timed out") from None
149+
except Exception as e:
150+
if self.message_id in self.pending_messages:
151+
del self.pending_messages[self.message_id]
152+
raise e
153+
154+
async def _handle_incoming_messages(self) -> None:
155+
"""Handle incoming WebSocket messages from Chrome."""
156+
try:
157+
if self.ws is not None:
158+
async for message in self.ws:
159+
try:
160+
data = json.loads(message)
161+
162+
if "id" in data:
163+
message_id = data["id"]
164+
if message_id in self.pending_messages:
165+
future = self.pending_messages.pop(message_id)
166+
if "error" in data:
167+
future.set_exception(Exception(data["error"]["message"]))
168+
else:
169+
future.set_result(data.get("result", {}))
170+
171+
elif "method" in data:
172+
await self._process_event(data)
173+
174+
except json.JSONDecodeError:
175+
logger.warning("Received invalid JSON from Chrome")
176+
except Exception as e:
177+
logger.error(f"Error processing message: {e}")
178+
179+
except websockets.exceptions.ConnectionClosed:
180+
logger.info("Chrome connection closed")
181+
self.connected = False
182+
except Exception as e:
183+
logger.error(f"Error in message handler: {e}")
184+
self.connected = False
185+
186+
async def _process_event(self, event: dict[str, Any]) -> None:
187+
"""Process CDP event notifications and store relevant data."""
188+
method = event["method"]
189+
params = event.get("params", {})
190+
191+
if method == "Network.requestWillBeSent":
192+
await self._process_network_request(params)
193+
elif method == "Network.responseReceived":
194+
await self._process_network_response(params)
195+
elif method == "Network.loadingFinished":
196+
await self._process_network_completion(params)
197+
elif method == "Network.loadingFailed":
198+
await self._process_network_failure(params)
199+
elif method == "Runtime.consoleAPICalled":
200+
await self._process_console_message(params)
201+
elif method == "Runtime.exceptionThrown":
202+
await self._process_console_exception(params)
203+
204+
if method in self.event_handlers:
205+
for handler in self.event_handlers[method]:
206+
try:
207+
if asyncio.iscoroutinefunction(handler):
208+
await handler(params)
209+
else:
210+
handler(params)
211+
except Exception as e:
212+
logger.error(f"Error in event handler for {method}: {e}")
213+
214+
async def _process_network_request(self, params: dict[str, Any]) -> None:
215+
"""Process network request event."""
216+
from .tools.utils import safe_timestamp_conversion
217+
218+
self.network_requests.append(
219+
{
220+
"requestId": params["requestId"],
221+
"url": params["request"]["url"],
222+
"method": params["request"]["method"],
223+
"headers": params["request"].get("headers", {}),
224+
"timestamp": safe_timestamp_conversion(params["timestamp"]),
225+
"type": "request",
226+
"status": "pending",
227+
}
228+
)
229+
230+
async def _process_network_response(self, params: dict[str, Any]) -> None:
231+
"""Process network response event."""
232+
from .tools.utils import safe_timestamp_conversion
233+
234+
request_id = params["requestId"]
235+
for req in self.network_requests:
236+
if req.get("requestId") == request_id and req["type"] == "request":
237+
req.update(
238+
{
239+
"response": {
240+
"status": params["response"]["status"],
241+
"statusText": params["response"]["statusText"],
242+
"headers": params["response"]["headers"],
243+
"mimeType": params["response"]["mimeType"],
244+
"timestamp": safe_timestamp_conversion(params["timestamp"]),
245+
"remoteIPAddress": params["response"].get("remoteIPAddress"),
246+
"protocol": params["response"].get("protocol"),
247+
},
248+
"status": "responded",
249+
}
250+
)
251+
break
252+
253+
async def _process_network_completion(self, params: dict[str, Any]) -> None:
254+
"""Process network loading completion event."""
255+
request_id = params["requestId"]
256+
for req in self.network_requests:
257+
if req.get("requestId") == request_id:
258+
req.update(
259+
{"status": "completed", "encodedDataLength": params.get("encodedDataLength")}
260+
)
261+
break
262+
263+
async def _process_network_failure(self, params: dict[str, Any]) -> None:
264+
"""Process network loading failure event."""
265+
request_id = params["requestId"]
266+
for req in self.network_requests:
267+
if req.get("requestId") == request_id:
268+
req.update(
269+
{
270+
"status": "failed",
271+
"errorText": params.get("errorText"),
272+
"cancelled": params.get("canceled", False),
273+
}
274+
)
275+
break
276+
277+
async def _process_console_message(self, params: dict[str, Any]) -> None:
278+
"""Process console API call event."""
279+
from .tools.utils import safe_timestamp_conversion
280+
281+
self.console_logs.append(
282+
{
283+
"type": params["type"],
284+
"args": [arg.get("value", str(arg)) for arg in params["args"]],
285+
"timestamp": safe_timestamp_conversion(params["timestamp"]),
286+
"executionContextId": params.get("executionContextId"),
287+
"stackTrace": params.get("stackTrace"),
288+
}
289+
)
290+
291+
async def _process_console_exception(self, params: dict[str, Any]) -> None:
292+
"""Process console exception event."""
293+
from .tools.utils import safe_timestamp_conversion
294+
295+
exception = params["exceptionDetails"]
296+
self.console_logs.append(
297+
{
298+
"type": "error",
299+
"args": [exception.get("text", "Unknown error")],
300+
"timestamp": safe_timestamp_conversion(params["timestamp"]),
301+
"executionContextId": exception.get("executionContextId"),
302+
"stackTrace": exception.get("stackTrace"),
303+
"exception": True,
304+
}
305+
)
306+
307+
async def enable_domains(self) -> None:
308+
"""Enable necessary CDP domains for functionality."""
309+
domains = [
310+
"Network",
311+
"Runtime",
312+
"Page",
313+
"Performance",
314+
"DOM",
315+
"CSS",
316+
"Security",
317+
"DOMStorage",
318+
]
319+
for domain in domains:
320+
try:
321+
await self.send_command(f"{domain}.enable")
322+
logger.info(f"{domain} domain enabled")
323+
except Exception as e:
324+
logger.warning(f"Failed to enable {domain} domain: {e}")
325+
326+
async def get_target_info(self) -> dict[str, Any]:
327+
"""Get information about the current target."""
328+
try:
329+
return await self.send_command("Target.getTargetInfo")
330+
except Exception:
331+
return {"title": "Unknown", "url": "Unknown"}
332+
333+
def add_event_handler(
334+
self, event_method: str, handler: Callable[[dict[str, Any]], None]
335+
) -> None:
336+
"""Register an event handler for a specific CDP event."""
337+
if event_method not in self.event_handlers:
338+
self.event_handlers[event_method] = []
339+
self.event_handlers[event_method].append(handler)

0 commit comments

Comments
 (0)