Skip to content

Commit cef606f

Browse files
committed
feat: pluggable URL parsing
This way people can customize it to their liking, as there a lot of opinions about this, as evidenced by the comments on GH-34. The default parsing is still the same as before, so new versions don't break existing code. But the user has the option of passing in a settings object, which has a `urlparse` attribute that can be set to a custom function that processes the URL and splits it into a `sockpath` and a `reqpath`. Sem-Ver: feature
1 parent 8449bc0 commit cef606f

File tree

3 files changed

+136
-21
lines changed

3 files changed

+136
-21
lines changed

requests_unixsocket/__init__.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,33 @@
1-
import requests
21
import sys
32

3+
import requests
4+
from requests.compat import urlparse, unquote
5+
46
from .adapters import UnixAdapter
57

6-
DEFAULT_SCHEME = 'http+unix://'
8+
9+
def default_urlparse(url):
10+
parsed_url = urlparse(url)
11+
return UnixAdapter.Settings.ParseResult(
12+
sockpath=unquote(parsed_url.netloc),
13+
reqpath=parsed_url.path + '?' + parsed_url.query,
14+
)
15+
16+
17+
default_scheme = 'http+unix://'
18+
default_settings = UnixAdapter.Settings(urlparse=default_urlparse)
719

820

921
class Session(requests.Session):
10-
def __init__(self, url_scheme=DEFAULT_SCHEME, *args, **kwargs):
22+
def __init__(self, url_scheme=default_scheme, settings=None,
23+
*args, **kwargs):
1124
super(Session, self).__init__(*args, **kwargs)
12-
self.mount(url_scheme, UnixAdapter())
25+
self.settings = settings or default_settings
26+
self.mount(url_scheme, UnixAdapter(settings=self.settings))
1327

1428

1529
class monkeypatch(object):
16-
def __init__(self, url_scheme=DEFAULT_SCHEME):
30+
def __init__(self, url_scheme=default_scheme):
1731
self.session = Session()
1832
requests = self._get_global_requests_module()
1933

requests_unixsocket/adapters.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import socket
2+
from collections import namedtuple
23

34
from requests.adapters import HTTPAdapter
4-
from requests.compat import urlparse, unquote
5+
from requests.compat import urlparse
56

67
try:
78
import http.client as httplib
@@ -18,16 +19,12 @@
1819
# https://github.com/docker/docker-py/blob/master/docker/transport/unixconn.py
1920
class UnixHTTPConnection(httplib.HTTPConnection, object):
2021

21-
def __init__(self, unix_socket_url, timeout=60):
22-
"""Create an HTTP connection to a unix domain socket
23-
24-
:param unix_socket_url: A URL with a scheme of 'http+unix' and the
25-
netloc is a percent-encoded path to a unix domain socket. E.g.:
26-
'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid'
27-
"""
22+
def __init__(self, url, timeout=60, settings=None):
23+
"""Create an HTTP connection to a unix domain socket"""
2824
super(UnixHTTPConnection, self).__init__('localhost', timeout=timeout)
29-
self.unix_socket_url = unix_socket_url
25+
self.url = url
3026
self.timeout = timeout
27+
self.settings = settings
3128
self.sock = None
3229

3330
def __del__(self): # base class does not have d'tor
@@ -37,27 +34,40 @@ def __del__(self): # base class does not have d'tor
3734
def connect(self):
3835
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
3936
sock.settimeout(self.timeout)
40-
socket_path = unquote(urlparse(self.unix_socket_url).netloc)
41-
sock.connect(socket_path)
37+
sockpath = self.settings.urlparse(self.url).sockpath
38+
sock.connect(sockpath)
4239
self.sock = sock
4340

4441

4542
class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
4643

47-
def __init__(self, socket_path, timeout=60):
44+
def __init__(self, socket_path, timeout=60, settings=None):
4845
super(UnixHTTPConnectionPool, self).__init__(
4946
'localhost', timeout=timeout)
5047
self.socket_path = socket_path
5148
self.timeout = timeout
49+
self.settings = settings
5250

5351
def _new_conn(self):
54-
return UnixHTTPConnection(self.socket_path, self.timeout)
52+
return UnixHTTPConnection(
53+
url=self.socket_path,
54+
timeout=self.timeout,
55+
settings=self.settings,
56+
)
5557

5658

5759
class UnixAdapter(HTTPAdapter):
60+
class Settings(object):
61+
class ParseResult(namedtuple('ParseResult', 'sockpath reqpath')):
62+
pass
63+
64+
def __init__(self, urlparse=None):
65+
self.urlparse = urlparse
5866

59-
def __init__(self, timeout=60, pool_connections=25, *args, **kwargs):
67+
def __init__(self, timeout=60, pool_connections=25, settings=None,
68+
*args, **kwargs):
6069
super(UnixAdapter, self).__init__(*args, **kwargs)
70+
self.settings = settings
6171
self.timeout = timeout
6272
self.pools = urllib3._collections.RecentlyUsedContainer(
6373
pool_connections, dispose_func=lambda p: p.close()
@@ -76,13 +86,17 @@ def get_connection(self, url, proxies=None):
7686
if pool:
7787
return pool
7888

79-
pool = UnixHTTPConnectionPool(url, self.timeout)
89+
pool = UnixHTTPConnectionPool(
90+
socket_path=url,
91+
settings=self.settings,
92+
timeout=self.timeout,
93+
)
8094
self.pools[url] = pool
8195

8296
return pool
8397

8498
def request_url(self, request, proxies):
85-
return request.path_url
99+
return self.settings.urlparse(request.url).reqpath
86100

87101
def close(self):
88102
self.pools.clear()

requests_unixsocket/tests/test_requests_unixsocket.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
"""Tests for requests_unixsocket"""
55

66
import logging
7+
import os
8+
import stat
79

810
import pytest
911
import requests
12+
from requests.compat import urlparse
1013

1114
import requests_unixsocket
1215
from requests_unixsocket.testutils import UnixSocketServerThread
@@ -15,6 +18,35 @@
1518
logger = logging.getLogger(__name__)
1619

1720

21+
def is_socket(path):
22+
try:
23+
mode = os.stat(path).st_mode
24+
return stat.S_ISSOCK(mode)
25+
except OSError:
26+
return False
27+
28+
29+
def get_sock_prefix(path):
30+
"""Keep going up directory tree until we find a socket"""
31+
32+
sockpath = path
33+
reqpath_parts = []
34+
35+
while not is_socket(sockpath):
36+
sockpath, tail = os.path.split(sockpath)
37+
reqpath_parts.append(tail)
38+
39+
return requests_unixsocket.UnixAdapter.Settings.ParseResult(
40+
sockpath=sockpath,
41+
reqpath='/' + os.path.join(*reversed(reqpath_parts)),
42+
)
43+
44+
45+
alt_settings_1 = requests_unixsocket.UnixAdapter.Settings(
46+
urlparse=lambda url: get_sock_prefix(urlparse(url).path),
47+
)
48+
49+
1850
def test_unix_domain_adapter_ok():
1951
with UnixSocketServerThread() as usock_thread:
2052
session = requests_unixsocket.Session('http+unix://')
@@ -41,6 +73,34 @@ def test_unix_domain_adapter_ok():
4173
assert r.text == 'Hello world!'
4274

4375

76+
def test_unix_domain_adapter_alt_settings_1_ok():
77+
with UnixSocketServerThread() as usock_thread:
78+
session = requests_unixsocket.Session(
79+
url_scheme='http+unix://',
80+
settings=alt_settings_1,
81+
)
82+
url = 'http+unix://localhost%s/path/to/page' % usock_thread.usock
83+
84+
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
85+
'options']:
86+
logger.debug('Calling session.%s(%r) ...', method, url)
87+
r = getattr(session, method)(url)
88+
logger.debug(
89+
'Received response: %r with text: %r and headers: %r',
90+
r, r.text, r.headers)
91+
assert r.status_code == 200
92+
assert r.headers['server'] == 'waitress'
93+
assert r.headers['X-Transport'] == 'unix domain socket'
94+
assert r.headers['X-Requested-Path'] == '/path/to/page'
95+
assert r.headers['X-Socket-Path'] == usock_thread.usock
96+
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
97+
assert r.url.lower() == url.lower()
98+
if method == 'head':
99+
assert r.text == ''
100+
else:
101+
assert r.text == 'Hello world!'
102+
103+
44104
def test_unix_domain_adapter_url_with_query_params():
45105
with UnixSocketServerThread() as usock_thread:
46106
session = requests_unixsocket.Session('http+unix://')
@@ -69,6 +129,33 @@ def test_unix_domain_adapter_url_with_query_params():
69129
assert r.text == 'Hello world!'
70130

71131

132+
def test_unix_domain_adapter_url_with_fragment():
133+
with UnixSocketServerThread() as usock_thread:
134+
session = requests_unixsocket.Session('http+unix://')
135+
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
136+
url = ('http+unix://%s'
137+
'/containers/nginx/logs#some-fragment' % urlencoded_usock)
138+
139+
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
140+
'options']:
141+
logger.debug('Calling session.%s(%r) ...', method, url)
142+
r = getattr(session, method)(url)
143+
logger.debug(
144+
'Received response: %r with text: %r and headers: %r',
145+
r, r.text, r.headers)
146+
assert r.status_code == 200
147+
assert r.headers['server'] == 'waitress'
148+
assert r.headers['X-Transport'] == 'unix domain socket'
149+
assert r.headers['X-Requested-Path'] == '/containers/nginx/logs'
150+
assert r.headers['X-Socket-Path'] == usock_thread.usock
151+
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
152+
assert r.url.lower() == url.lower()
153+
if method == 'head':
154+
assert r.text == ''
155+
else:
156+
assert r.text == 'Hello world!'
157+
158+
72159
def test_unix_domain_adapter_connection_error():
73160
session = requests_unixsocket.Session('http+unix://')
74161

0 commit comments

Comments
 (0)