Skip to content

Commit 3996881

Browse files
committed
Add Headers support
1 parent 31bc909 commit 3996881

File tree

10 files changed

+111
-52
lines changed

10 files changed

+111
-52
lines changed

cmd/scan.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ var scanCmd = &cobra.Command{
3535
outputFile := cmd.Flag("output").Value.String()
3636
outputFormat := utils.DetectOutputFormat(outputFile)
3737

38+
headers, _ := cmd.Flags().GetStringArray("header")
39+
3840
opts := scanner.ScanOptions{
3941
URL: cmd.Flag("url").Value.String(),
4042
File: cmd.Flag("file").Value.String(),
@@ -45,6 +47,7 @@ var scanCmd = &cobra.Command{
4547
Verbose: mustBool(cmd.Flags().GetBool("verbose")),
4648
ScanMode: cmd.Flag("mode").Value.String(),
4749
PluginList: cmd.Flag("plugin-list").Value.String(),
50+
Headers: headers,
4851
}
4952

5053
if opts.URL == "" && opts.File == "" {
@@ -66,6 +69,8 @@ func init() {
6669
scanCmd.Flags().StringP("mode", "m", "stealthy", "Scan mode: stealthy, bruteforce, or hybrid")
6770
scanCmd.Flags().
6871
StringP("plugin-list", "p", "", "Path to a custom plugin list file for bruteforce mode")
72+
scanCmd.Flags().
73+
StringArrayP("header", "H", []string{}, "HTTP header to include in requests. Can be specified multiple times.")
6974
}
7075

7176
func mustBool(value bool, err error) bool {

internal/scanner/bruteforce.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func BruteforcePlugins(
5555
plugins []string,
5656
threads int,
5757
progress *utils.ProgressManager,
58+
headers []string,
5859
) []string {
5960
if len(plugins) == 0 {
6061
utils.DefaultLogger.Warning("No plugins provided for brute-force scan")
@@ -88,7 +89,7 @@ func BruteforcePlugins(
8889
progress.SetMessage(fmt.Sprintf("🔎 Bruteforcing plugin %-30.30s", p))
8990
}
9091

91-
version := utils.GetPluginVersion(normalized, p, threads)
92+
version := utils.GetPluginVersion(normalized, p, threads, headers)
9293
if version != "" && version != "unknown" {
9394
if progress != nil {
9495
progress.ClearLine()
@@ -120,9 +121,10 @@ func HybridScan(
120121
bruteforcePlugins []string,
121122
threads int,
122123
progress *utils.ProgressManager,
124+
headers []string,
123125
) []string {
124126
if len(stealthyPlugins) == 0 {
125-
return BruteforcePlugins(target, bruteforcePlugins, threads, progress)
127+
return BruteforcePlugins(target, bruteforcePlugins, threads, progress, headers)
126128
}
127129

128130
detectedMap := make(map[string]bool, len(stealthyPlugins))
@@ -137,7 +139,7 @@ func HybridScan(
137139
}
138140
}
139141

140-
brutefound := BruteforcePlugins(target, remaining, threads, progress)
142+
brutefound := BruteforcePlugins(target, remaining, threads, progress, headers)
141143
result := make([]string, len(stealthyPlugins), len(stealthyPlugins)+len(brutefound))
142144
copy(result, stealthyPlugins)
143145
return append(result, brutefound...)

internal/scanner/endpoints.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ func fetchEndpointsFromPath(target, path string, httpClient *utils.HTTPClientMan
5151
return endpoints
5252
}
5353

54-
func FetchEndpoints(target string) []string {
55-
httpClient := utils.NewHTTPClient(10 * time.Second)
54+
func FetchEndpoints(target string, headers []string) []string {
55+
httpClient := utils.NewHTTPClient(10*time.Second, headers)
5656

5757
endpointsChan := make(chan []string, 2)
5858
var wg sync.WaitGroup

internal/scanner/endpoints_test.go

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,53 +31,71 @@ import (
3131
func TestFetchEndpoints(t *testing.T) {
3232
tests := []struct {
3333
name string
34-
target string
3534
mockServer func(w http.ResponseWriter, r *http.Request)
35+
headers []string
3636
want []string
3737
}{
3838
{
39-
name: "Valid response with routes",
39+
name: "Valid response with routes and header present",
4040
mockServer: func(w http.ResponseWriter, r *http.Request) {
41+
if r.Header.Get("X-Test") != "value" {
42+
w.WriteHeader(http.StatusForbidden)
43+
return
44+
}
4145
response := map[string]interface{}{
4246
"routes": map[string]interface{}{
4347
"/wp/v2/posts": nil,
4448
"/wp/v2/comments": nil,
4549
"/wp/v2/categories": nil,
4650
},
4751
}
48-
if err := json.NewEncoder(w).Encode(response); err != nil {
49-
t.Errorf("Failed to encode JSON response: %v", err)
52+
_ = json.NewEncoder(w).Encode(response)
53+
},
54+
headers: []string{"X-Test: value"},
55+
want: []string{"/wp/v2/posts", "/wp/v2/comments", "/wp/v2/categories"},
56+
},
57+
{
58+
name: "Valid response with routes but header missing",
59+
mockServer: func(w http.ResponseWriter, r *http.Request) {
60+
if r.Header.Get("X-Test") == "" {
61+
response := map[string]interface{}{
62+
"routes": map[string]interface{}{
63+
"/wp/v2/posts": nil,
64+
},
65+
}
66+
_ = json.NewEncoder(w).Encode(response)
67+
return
5068
}
69+
w.WriteHeader(http.StatusInternalServerError)
5170
},
52-
want: []string{"/wp/v2/posts", "/wp/v2/comments", "/wp/v2/categories"},
71+
headers: nil,
72+
want: []string{"/wp/v2/posts"},
5373
},
5474
{
5575
name: "Response without routes",
5676
mockServer: func(w http.ResponseWriter, r *http.Request) {
57-
response := map[string]interface{}{
77+
_ = json.NewEncoder(w).Encode(map[string]interface{}{
5878
"data": "No routes here",
59-
}
60-
if err := json.NewEncoder(w).Encode(response); err != nil {
61-
t.Errorf("Failed to encode JSON response: %v", err)
62-
}
79+
})
6380
},
64-
want: []string{},
81+
headers: nil,
82+
want: []string{},
6583
},
6684
{
6785
name: "Invalid JSON response",
6886
mockServer: func(w http.ResponseWriter, r *http.Request) {
69-
if _, err := w.Write([]byte("{invalid-json")); err != nil {
70-
t.Errorf("Failed to write invalid JSON: %v", err)
71-
}
87+
_, _ = w.Write([]byte("{invalid-json"))
7288
},
73-
want: []string{},
89+
headers: nil,
90+
want: []string{},
7491
},
7592
{
7693
name: "HTTP error response",
7794
mockServer: func(w http.ResponseWriter, r *http.Request) {
7895
w.WriteHeader(http.StatusInternalServerError)
7996
},
80-
want: []string{},
97+
headers: nil,
98+
want: []string{},
8199
},
82100
}
83101

@@ -86,7 +104,7 @@ func TestFetchEndpoints(t *testing.T) {
86104
server := httptest.NewServer(http.HandlerFunc(tt.mockServer))
87105
defer server.Close()
88106

89-
got := FetchEndpoints(server.URL)
107+
got := FetchEndpoints(server.URL, tt.headers)
90108

91109
sort.Strings(got)
92110
sort.Strings(tt.want)

internal/scanner/scan.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type ScanOptions struct {
3737
Verbose bool
3838
ScanMode string
3939
PluginList string
40+
Headers []string
4041
}
4142

4243
func ScanTargets(opts ScanOptions) {
@@ -114,7 +115,7 @@ func performStealthyScan(
114115
return nil, PluginDetectionResult{}
115116
}
116117

117-
endpoints := FetchEndpoints(target)
118+
endpoints := FetchEndpoints(target, opts.Headers)
118119
if len(endpoints) == 0 {
119120
if opts.File == "" {
120121
utils.DefaultLogger.Warning("No REST endpoints found on " + target)
@@ -153,7 +154,7 @@ func performBruteforceScan(
153154
}
154155
}
155156

156-
detected := BruteforcePlugins(target, plugins, threads, pb)
157+
detected := BruteforcePlugins(target, plugins, threads, pb, opts.Headers)
157158

158159
pr := PluginDetectionResult{
159160
Plugins: make(map[string]*PluginData, len(detected)),
@@ -209,7 +210,7 @@ func performHybridScan(
209210
bruteBar = utils.NewProgressBar(len(remaining), "🔎 Bruteforcing remaining")
210211
}
211212

212-
brutefound := BruteforcePlugins(target, remaining, threads, bruteBar)
213+
brutefound := BruteforcePlugins(target, remaining, threads, bruteBar, opts.Headers)
213214

214215
if bruteBar != nil {
215216
bruteBar.Finish()
@@ -306,7 +307,7 @@ func ScanSite(
306307

307308
version := "unknown"
308309
if !opts.NoCheckVersion {
309-
version = utils.GetPluginVersion(target, pl, opts.Threads)
310+
version = utils.GetPluginVersion(target, pl, opts.Threads, opts.Headers)
310311
}
311312

312313
var matched []wordfence.Vulnerability

internal/utils/http.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ const maxRedirects = 10
4141
type HTTPClientManager struct {
4242
client *http.Client
4343
userAgent string
44+
headers []string
4445
}
4546

46-
func NewHTTPClient(timeout time.Duration) *HTTPClientManager {
47+
func NewHTTPClient(timeout time.Duration, headers []string) *HTTPClientManager {
4748
transport := &http.Transport{
4849
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
4950
DisableKeepAlives: true,
@@ -63,6 +64,7 @@ func NewHTTPClient(timeout time.Duration) *HTTPClientManager {
6364
return &HTTPClientManager{
6465
client: client,
6566
userAgent: uarand.GetRandom(),
67+
headers: headers,
6668
}
6769
}
6870

@@ -72,7 +74,35 @@ func (h *HTTPClientManager) Get(url string) (string, error) {
7274
return "", errors.New("failed to create request: " + err.Error())
7375
}
7476

75-
req.Header.Set("User-Agent", h.userAgent)
77+
hasUA := false
78+
for _, hdr := range h.headers {
79+
parts := strings.SplitN(hdr, ":", 2)
80+
if len(parts) != 2 {
81+
continue
82+
}
83+
key := strings.TrimSpace(parts[0])
84+
if strings.EqualFold(key, "User-Agent") {
85+
req.Header.Add("User-Agent", strings.TrimSpace(parts[1]))
86+
hasUA = true
87+
}
88+
}
89+
90+
if !hasUA {
91+
req.Header.Set("User-Agent", h.userAgent)
92+
}
93+
94+
for _, hdr := range h.headers {
95+
parts := strings.SplitN(hdr, ":", 2)
96+
if len(parts) != 2 {
97+
continue
98+
}
99+
key := strings.TrimSpace(parts[0])
100+
if strings.EqualFold(key, "User-Agent") {
101+
continue
102+
}
103+
value := strings.TrimSpace(parts[1])
104+
req.Header.Add(key, value)
105+
}
76106

77107
resp, err := h.client.Do(req)
78108
if err != nil {
@@ -117,20 +147,14 @@ func (h *HTTPClientManager) Get(url string) (string, error) {
117147
return string(data), nil
118148
}
119149

120-
// NormalizeURL ensures the URL has the correct format (removes trailing slash)
121150
func NormalizeURL(url string) string {
122-
// Remove trailing slash if present
123151
url = strings.TrimSuffix(url, "/")
124-
125-
// Ensure URL has http:// or https:// prefix
126152
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
127153
url = "https://" + url
128154
}
129-
130155
return url
131156
}
132157

133-
// SplitLines splits a byte array into lines
134158
func SplitLines(data []byte) []string {
135159
var lines []string
136160
scanner := bufio.NewScanner(bytes.NewReader(data))

internal/utils/http_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func TestHTTPClientManager_Get(t *testing.T) {
7979
mockServer := httptest.NewServer(tt.serverFunc)
8080
defer mockServer.Close()
8181

82-
client := NewHTTPClient(5 * time.Second)
82+
client := NewHTTPClient(5*time.Second, nil)
8383

8484
got, err := client.Get(mockServer.URL)
8585

internal/utils/logger_test.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,25 @@
1717
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
1818
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919

20+
// Copyright (c) 2025 Valentin Lobstein (Chocapikk) <balgogan@protonmail.com>
21+
//
22+
// Permission is hereby granted, free of charge, to any person obtaining a copy of
23+
// this software and associated documentation files (the "Software"), to deal in
24+
// the Software without restriction, including without limitation the rights to
25+
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
26+
// the Software, and to permit persons to whom the Software is furnished to do so,
27+
// subject to the following conditions:
28+
//
29+
// The above copyright notice and this permission notice shall be included in all
30+
// copies or substantial portions of the Software.
31+
//
32+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
34+
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
35+
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
36+
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
37+
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
38+
2039
package utils
2140

2241
import (
@@ -45,7 +64,7 @@ func TestLogger_Info(t *testing.T) {
4564
var buf bytes.Buffer
4665
originalLogger := DefaultLogger.Logger
4766
DefaultLogger.Logger = log.New(&buf, "", 0)
48-
defer func() { DefaultLogger.Logger = originalLogger }() // Restaure l'ancien Logger
67+
defer func() { DefaultLogger.Logger = originalLogger }()
4968

5069
msg := "This is an info message"
5170
DefaultLogger.Info(msg)
@@ -105,14 +124,13 @@ func TestLogger_PrintBanner(t *testing.T) {
105124
version := "v1.0.0"
106125
isLatest := true
107126
DefaultLogger.PrintBanner(version, isLatest)
108-
defer func() { _ = w.Close() }()
127+
_ = w.Close()
109128

110129
var outBuf bytes.Buffer
111130
_, _ = outBuf.ReadFrom(r)
112131
os.Stdout = originalStdout
113132

114133
output := outBuf.String()
115-
116134
if !strings.Contains(output, version) || !strings.Contains(output, "latest") {
117135
t.Errorf("PrintBanner() output = %v, want version %v and 'latest'", output, version)
118136
}
@@ -121,14 +139,13 @@ func TestLogger_PrintBanner(t *testing.T) {
121139
os.Stdout = w
122140

123141
DefaultLogger.PrintBanner(version, false)
142+
_ = w.Close()
124143

125-
defer func() { _ = w.Close() }()
126144
outBuf.Reset()
127145
_, _ = outBuf.ReadFrom(r)
128146
os.Stdout = originalStdout
129147

130148
output = outBuf.String()
131-
132149
if !strings.Contains(output, "outdated") {
133150
t.Errorf("PrintBanner() output = %v, want 'outdated'", output)
134151
}

internal/utils/version.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ func CheckLatestVersion(currentVersion string) (string, bool) {
7070
return latest.String(), curr.Compare(latest) >= 0
7171
}
7272

73-
func GetPluginVersion(target, plugin string, _ int) string {
74-
httpClient := NewHTTPClient(10 * time.Second)
73+
func GetPluginVersion(target, plugin string, _ int, headers []string) string {
74+
httpClient := NewHTTPClient(10*time.Second, headers)
7575
return fetchVersionFromReadme(httpClient, target, plugin)
7676
}
7777

0 commit comments

Comments
 (0)