@@ -29,18 +29,23 @@ func init() {
29
29
}
30
30
}
31
31
32
- func isPrivateIP ( ip net.IP ) bool {
33
- for _ , block := range privateIPBlocks {
32
+ func blocksContain ( blocks [] * net. IPNet , ip net.IP ) bool {
33
+ for _ , block := range blocks {
34
34
if block .Contains (ip ) {
35
35
return true
36
36
}
37
37
}
38
38
return false
39
39
}
40
40
41
+ func isPrivateIP (ip net.IP ) bool {
42
+ return blocksContain (privateIPBlocks , ip )
43
+ }
44
+
41
45
type noLocalTransport struct {
42
- inner http.RoundTripper
43
- errlog logrus.FieldLogger
46
+ inner http.RoundTripper
47
+ errlog logrus.FieldLogger
48
+ allowedBlocks []* net.IPNet
44
49
}
45
50
46
51
func (no noLocalTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
@@ -61,34 +66,38 @@ func (no noLocalTransport) RoundTrip(req *http.Request) (*http.Response, error)
61
66
return
62
67
}
63
68
69
+ if blocksContain (no .allowedBlocks , ip ) {
70
+ return
71
+ }
72
+
64
73
if isPrivateIP (ip ) {
65
74
cancel ()
66
75
no .errlog .Error ("Cancelled attempted request to ip in private range" )
67
76
return
68
77
}
69
-
70
78
},
71
79
})
72
80
73
81
req = req .WithContext (ctx )
74
82
return no .inner .RoundTrip (req )
75
83
}
76
84
77
- func SafeRountripper (trans http.RoundTripper , log logrus.FieldLogger ) http.RoundTripper {
85
+ func SafeRountripper (trans http.RoundTripper , log logrus.FieldLogger , allowedBlocks ... * net. IPNet ) http.RoundTripper {
78
86
if trans == nil {
79
87
trans = http .DefaultTransport
80
88
}
81
89
82
90
ret := & noLocalTransport {
83
- inner : trans ,
84
- errlog : log .WithField ("transport" , "local_blocker" ),
91
+ inner : trans ,
92
+ errlog : log .WithField ("transport" , "local_blocker" ),
93
+ allowedBlocks : allowedBlocks ,
85
94
}
86
95
87
96
return ret
88
97
}
89
98
90
- func SafeHTTPClient (client * http.Client , log logrus.FieldLogger ) * http.Client {
91
- client .Transport = SafeRountripper (client .Transport , log )
99
+ func SafeHTTPClient (client * http.Client , log logrus.FieldLogger , allowedBlocks ... * net. IPNet ) * http.Client {
100
+ client .Transport = SafeRountripper (client .Transport , log , allowedBlocks ... )
92
101
93
102
return client
94
103
}
0 commit comments