Skip to content

Commit 8eafec9

Browse files
committed
Allow whitelisting local IPs for development
1 parent 0648439 commit 8eafec9

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

http/http.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,23 @@ func init() {
2929
}
3030
}
3131

32-
func isPrivateIP(ip net.IP) bool {
33-
for _, block := range privateIPBlocks {
32+
func blocksContain(blocks []*net.IPNet, ip net.IP) bool {
33+
for _, block := range blocks {
3434
if block.Contains(ip) {
3535
return true
3636
}
3737
}
3838
return false
3939
}
4040

41+
func isPrivateIP(ip net.IP) bool {
42+
return blocksContain(privateIPBlocks, ip)
43+
}
44+
4145
type noLocalTransport struct {
42-
inner http.RoundTripper
43-
errlog logrus.FieldLogger
46+
inner http.RoundTripper
47+
errlog logrus.FieldLogger
48+
allowedBlocks []*net.IPNet
4449
}
4550

4651
func (no noLocalTransport) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -61,34 +66,38 @@ func (no noLocalTransport) RoundTrip(req *http.Request) (*http.Response, error)
6166
return
6267
}
6368

69+
if blocksContain(no.allowedBlocks, ip) {
70+
return
71+
}
72+
6473
if isPrivateIP(ip) {
6574
cancel()
6675
no.errlog.Error("Cancelled attempted request to ip in private range")
6776
return
6877
}
69-
7078
},
7179
})
7280

7381
req = req.WithContext(ctx)
7482
return no.inner.RoundTrip(req)
7583
}
7684

77-
func SafeRountripper(trans http.RoundTripper, log logrus.FieldLogger) http.RoundTripper {
85+
func SafeRountripper(trans http.RoundTripper, log logrus.FieldLogger, allowedBlocks ...*net.IPNet) http.RoundTripper {
7886
if trans == nil {
7987
trans = http.DefaultTransport
8088
}
8189

8290
ret := &noLocalTransport{
83-
inner: trans,
84-
errlog: log.WithField("transport", "local_blocker"),
91+
inner: trans,
92+
errlog: log.WithField("transport", "local_blocker"),
93+
allowedBlocks: allowedBlocks,
8594
}
8695

8796
return ret
8897
}
8998

90-
func SafeHTTPClient(client *http.Client, log logrus.FieldLogger) *http.Client {
91-
client.Transport = SafeRountripper(client.Transport, log)
99+
func SafeHTTPClient(client *http.Client, log logrus.FieldLogger, allowedBlocks ...*net.IPNet) *http.Client {
100+
client.Transport = SafeRountripper(client.Transport, log, allowedBlocks...)
92101

93102
return client
94103
}

http/http_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,23 @@ func TestSafeHTTPClient(t *testing.T) {
4242

4343
client := SafeHTTPClient(&http.Client{}, logrus.New())
4444

45-
// It blocks the local IP
45+
// It blocks the local IP.
4646
_, err = client.Get(ts.URL)
4747
assert.NotNil(t, err)
4848

49-
// It blocks localhost
49+
// It blocks localhost.
5050
_, err = client.Get("http://localhost:" + tsURL.Port())
5151
assert.NotNil(t, err)
5252

5353
// It succeeds when the local IP range used by the testserver is removed from
5454
// the blacklist.
5555
ipNet := popMatchingBlock(net.ParseIP(tsURL.Hostname()))
56-
defer func() {
57-
privateIPBlocks = append(privateIPBlocks, ipNet)
58-
}()
56+
_, err = client.Get(ts.URL)
57+
assert.Nil(t, err)
58+
privateIPBlocks = append(privateIPBlocks, ipNet)
5959

60+
// It allows whitelisting for local development.
61+
client = SafeHTTPClient(&http.Client{}, logrus.New(), ipNet)
6062
_, err = client.Get(ts.URL)
6163
assert.Nil(t, err)
6264
}

0 commit comments

Comments
 (0)