Skip to content

Commit 84e50ab

Browse files
authored
fix: properly handle urls in saferoundtripper (#267)
* debug logging * tdd: added test case for failure behavior * now attempting to resolve url
1 parent 8cfaa49 commit 84e50ab

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

http/http.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@ func init() {
2929
}
3030
}
3131

32-
func blocksContain(blocks []*net.IPNet, ip net.IP) bool {
32+
func blocksContainsAny(blocks []*net.IPNet, ips []net.IP) bool {
3333
for _, block := range blocks {
34-
if block.Contains(ip) {
35-
return true
34+
for _, ip := range ips {
35+
if block.Contains(ip) {
36+
return true
37+
}
3638
}
3739
}
3840
return false
3941
}
4042

41-
func isPrivateIP(ip net.IP) bool {
42-
return blocksContain(privateIPBlocks, ip)
43+
func containsPrivateIP(ips []net.IP) bool {
44+
return blocksContainsAny(privateIPBlocks, ips)
4345
}
4446

4547
type noLocalTransport struct {
@@ -59,18 +61,19 @@ func (no noLocalTransport) RoundTrip(req *http.Request) (*http.Response, error)
5961
no.errlog.WithError(err).Error("Cancelled request due to error in address parsing")
6062
return
6163
}
62-
ip := net.ParseIP(host)
63-
if ip == nil {
64+
65+
ips, err := net.LookupIP(host)
66+
if err != nil || len(ips) == 0 {
6467
cancel()
65-
no.errlog.WithError(err).Error("Cancelled request due to error in ip parsing")
68+
no.errlog.WithError(err).Error("Cancelled request due to error in host lookup")
6669
return
6770
}
6871

69-
if blocksContain(no.allowedBlocks, ip) {
72+
if blocksContainsAny(no.allowedBlocks, ips) {
7073
return
7174
}
7275

73-
if isPrivateIP(ip) {
76+
if containsPrivateIP(ips) {
7477
cancel()
7578
no.errlog.Error("Cancelled attempted request to ip in private range")
7679
return

http/http_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ func TestIsPrivateIP(t *testing.T) {
2727

2828
for _, tt := range tests {
2929
ip := net.ParseIP(tt.ip)
30-
assert.Equal(t, tt.expected, isPrivateIP(ip))
30+
if ip == nil {
31+
require.Fail(t, "failed to parse IP")
32+
}
33+
assert.Equal(t, tt.expected, containsPrivateIP([]net.IP{ip}))
3134
}
3235
}
3336

@@ -43,6 +46,10 @@ func TestSafeHTTPClient(t *testing.T) {
4346

4447
client := SafeHTTPClient(&http.Client{}, logrus.New())
4548

49+
// It allows accessing non-local addresses
50+
_, err = client.Get("https://google.com")
51+
require.Nil(t, err)
52+
4653
// It blocks the local IP.
4754
_, err = client.Get(ts.URL)
4855
require.NotNil(t, err)

0 commit comments

Comments
 (0)