Skip to content

Commit 3d89331

Browse files
committed
feat: Allow resolving IP to specific server on run-all-tests
1 parent 24467a3 commit 3d89331

File tree

2 files changed

+124
-30
lines changed

2 files changed

+124
-30
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ This is needed when testing on edge, since `.../check` times out the workload.
310310
In conjunction with this, there is a convenience script which runs all tests each in a separate query.
311311

312312
So you may also run `./run-all-tests-via-api.py --host <hostname> --port <port>`.
313+
If you need to hit a specific IP while preserving the original hostname (e.g., for edge testing or custom DNS), use `--resolve-ip <ip>` which is SNI-compatible for HTTPS and sets the HTTP `Host` header accordingly.
313314
This is intended to be run to validate package functionaltiy on edge, as each test becomes a separate workload.
314315

315316
### Notes

run-all-tests-via-api.py

Lines changed: 123 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,121 @@
66
import sys
77
import tempfile
88
import time
9+
import socket
10+
import ssl
11+
import http.client
912
from pathlib import Path
10-
from typing import List, Tuple
11-
from urllib.error import HTTPError, URLError
12-
from urllib.parse import quote
13-
from urllib.request import Request, urlopen
13+
from typing import List, Optional, Tuple
14+
from urllib.parse import quote, urlparse
1415

1516

16-
def fetch_tests(base_url: str) -> List[str]:
17-
with urlopen(base_url + "/list", timeout=20) as resp:
18-
if resp.status != 200:
19-
raise RuntimeError(f"/list returned HTTP {resp.status}")
20-
payload = json.loads(resp.read().decode("utf-8"))
21-
tests = payload.get("tests", [])
22-
if not isinstance(tests, list):
23-
raise RuntimeError("Invalid /list payload: missing 'tests' list")
24-
return [str(t) for t in tests]
17+
def _default_port_for_scheme(scheme: str) -> int:
18+
return 443 if scheme == "https" else 80
2519

2620

27-
def run_single_test(base_url: str, test_name: str, timeout: float) -> Tuple[bool, str]:
21+
def _host_header(hostname: str, port: Optional[int], scheme: str) -> str:
22+
default_port = _default_port_for_scheme(scheme)
23+
if port and port != default_port:
24+
return f"{hostname}:{port}"
25+
return hostname
26+
27+
28+
class _ResolvedHTTPConnection(http.client.HTTPConnection):
29+
def __init__(self, resolved_host: str, port: int, timeout: float) -> None:
30+
super().__init__(host=resolved_host, port=port, timeout=timeout)
31+
self._resolved_host = resolved_host
32+
33+
# For HTTP we just connect to the resolved host as usual.
34+
35+
36+
class _ResolvedHTTPSConnection(http.client.HTTPSConnection):
37+
def __init__(self, resolved_host: str, port: int, timeout: float, *, server_hostname: str, context: Optional[ssl.SSLContext] = None) -> None:
38+
# Pass the resolved host/IP to the base class so it doesn't try to resolve
39+
super().__init__(host=resolved_host, port=port, timeout=timeout, context=context)
40+
self._resolved_host = resolved_host
41+
self._server_hostname = server_hostname
42+
43+
def connect(self) -> None:
44+
# Largely mirrors the stdlib implementation but pins the TCP connect
45+
# to the resolved host/IP and sets SNI to the original hostname.
46+
self.sock = socket.create_connection((self._resolved_host, self.port), self.timeout, self.source_address)
47+
if self._tunnel_host:
48+
self._tunnel()
49+
# Ensure we have a context
50+
if self._context is None:
51+
self._context = ssl.create_default_context()
52+
# Enable hostname checking by default
53+
self._context.check_hostname = True
54+
self.sock = self._context.wrap_socket(self.sock, server_hostname=self._server_hostname)
55+
56+
57+
def http_get_text(url: str, *, resolve_ip: Optional[str], timeout: float) -> Tuple[int, str]:
58+
"""Perform a GET request with optional DNS override and SNI support.
59+
60+
If resolve_ip is given, connects to that IP, sets the Host header to the
61+
original hostname, and (for HTTPS) uses SNI with the original hostname.
62+
Returns (status_code, text_body).
63+
"""
64+
parsed = urlparse(url)
65+
scheme = (parsed.scheme or "http").lower()
66+
hostname = parsed.hostname or ""
67+
port = parsed.port or _default_port_for_scheme(scheme)
68+
path = parsed.path or "/"
69+
if parsed.query:
70+
path += f"?{parsed.query}"
71+
72+
headers = {
73+
"Accept": "*/*",
74+
}
75+
76+
# If we override resolution, set the Host header explicitly
77+
if resolve_ip:
78+
headers["Host"] = _host_header(hostname, parsed.port, scheme)
79+
80+
try:
81+
if resolve_ip:
82+
if scheme == "https":
83+
context = ssl.create_default_context()
84+
conn = _ResolvedHTTPSConnection(resolve_ip, port, timeout, server_hostname=hostname, context=context)
85+
else:
86+
conn = _ResolvedHTTPConnection(resolve_ip, port, timeout)
87+
else:
88+
# No override — use stdlib conveniences
89+
if scheme == "https":
90+
conn = http.client.HTTPSConnection(hostname, port, timeout=timeout)
91+
else:
92+
conn = http.client.HTTPConnection(hostname, port, timeout=timeout)
93+
94+
conn.request("GET", path, headers=headers)
95+
resp = conn.getresponse()
96+
data = resp.read().decode("utf-8", errors="replace")
97+
status = resp.status
98+
conn.close()
99+
return status, data
100+
except Exception as e:
101+
# Normalize into a network error string like urllib would give
102+
return 0, f"Network error calling {url}: {e}"
103+
104+
105+
def fetch_tests(base_url: str, *, resolve_ip: Optional[str]) -> List[str]:
106+
status, body = http_get_text(base_url + "/list", resolve_ip=resolve_ip, timeout=20)
107+
if status != 200:
108+
raise RuntimeError(f"/list returned HTTP {status}")
109+
payload = json.loads(body)
110+
tests = payload.get("tests", [])
111+
if not isinstance(tests, list):
112+
raise RuntimeError("Invalid /list payload: missing 'tests' list")
113+
return [str(t) for t in tests]
114+
115+
116+
def run_single_test(base_url: str, test_name: str, timeout: float, *, resolve_ip: Optional[str]) -> Tuple[bool, str]:
28117
url = base_url + "/check/" + quote(test_name)
29-
req = Request(url, method="GET")
30118
try:
31119
print(f"Checking: {url}")
32-
with urlopen(req, timeout=timeout) as resp:
33-
output = resp.read().decode("utf-8", errors="replace")
34-
ok = resp.status == 200
35-
return ok, output
36-
except HTTPError as e:
37-
try:
38-
body = e.read().decode("utf-8", errors="replace")
39-
except Exception:
40-
body = str(e)
41-
# 417 or 500 considered failure
42-
return False, body
43-
except URLError as e:
120+
status, output = http_get_text(url, resolve_ip=resolve_ip, timeout=timeout)
121+
ok = status == 200
122+
return ok, output
123+
except Exception as e:
44124
return False, f"Network error calling {url}: {e}"
45125

46126

@@ -72,13 +152,26 @@ def main() -> int:
72152
help="Timeout for each test in seconds (default: 30.0)",
73153
)
74154

155+
parser.add_argument(
156+
"--resolve-ip",
157+
default=os.environ.get("RESOLVE_IP"),
158+
help=(
159+
"Optional IP to resolve the server hostname to. "
160+
"When set, connections go to this IP while preserving the original "
161+
"hostname for HTTP Host and TLS SNI (SNI-compatible)."
162+
),
163+
)
164+
75165
args = parser.parse_args()
76166

77-
base_url = f"{args.host}:{args.port}"
167+
host_value = args.host
168+
if not (host_value.startswith("http://") or host_value.startswith("https://")):
169+
host_value = "http://" + host_value
170+
base_url = f"{host_value}:{args.port}"
78171
outdir = Path(args.outdir)
79172
outdir.mkdir(parents=True, exist_ok=True)
80173

81-
tests = fetch_tests(base_url)
174+
tests = fetch_tests(base_url, resolve_ip=args.resolve_ip)
82175
if not tests:
83176
print("No tests returned by /list. Nothing to run.")
84177
return 0
@@ -92,7 +185,7 @@ def main() -> int:
92185

93186
for idx, test in enumerate(tests, start=1):
94187
print(f"[{idx}/{len(tests)}] Running {test} ...", flush=True)
95-
ok, output = run_single_test(base_url, test, timeout=args.test_timeout)
188+
ok, output = run_single_test(base_url, test, timeout=args.test_timeout, resolve_ip=args.resolve_ip)
96189
# Write log
97190
safe_name = test.replace(os.sep, "_")
98191
log_path = outdir / f"{safe_name}.log"

0 commit comments

Comments
 (0)