Skip to content

Commit 0cad3b3

Browse files
committed
Add Headers support
1 parent 31bc909 commit 0cad3b3

File tree

10 files changed

+92
-52
lines changed

10 files changed

+92
-52
lines changed

cmd/scan.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ var scanCmd = &cobra.Command{
3535
outputFile := cmd.Flag("output").Value.String()
3636
outputFormat := utils.DetectOutputFormat(outputFile)
3737

38+
headers, _ := cmd.Flags().GetStringArray("header")
39+
3840
opts := scanner.ScanOptions{
3941
URL: cmd.Flag("url").Value.String(),
4042
File: cmd.Flag("file").Value.String(),
@@ -45,6 +47,7 @@ var scanCmd = &cobra.Command{
4547
Verbose: mustBool(cmd.Flags().GetBool("verbose")),
4648
ScanMode: cmd.Flag("mode").Value.String(),
4749
PluginList: cmd.Flag("plugin-list").Value.String(),
50+
Headers: headers,
4851
}
4952

5053
if opts.URL == "" && opts.File == "" {
@@ -66,6 +69,8 @@ func init() {
6669
scanCmd.Flags().StringP("mode", "m", "stealthy", "Scan mode: stealthy, bruteforce, or hybrid")
6770
scanCmd.Flags().
6871
StringP("plugin-list", "p", "", "Path to a custom plugin list file for bruteforce mode")
72+
scanCmd.Flags().
73+
StringArrayP("header", "H", []string{}, "HTTP header to include in requests. Can be specified multiple times.")
6974
}
7075

7176
func mustBool(value bool, err error) bool {

internal/scanner/bruteforce.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func BruteforcePlugins(
5555
plugins []string,
5656
threads int,
5757
progress *utils.ProgressManager,
58+
headers []string,
5859
) []string {
5960
if len(plugins) == 0 {
6061
utils.DefaultLogger.Warning("No plugins provided for brute-force scan")
@@ -88,7 +89,7 @@ func BruteforcePlugins(
8889
progress.SetMessage(fmt.Sprintf("🔎 Bruteforcing plugin %-30.30s", p))
8990
}
9091

91-
version := utils.GetPluginVersion(normalized, p, threads)
92+
version := utils.GetPluginVersion(normalized, p, threads, headers)
9293
if version != "" && version != "unknown" {
9394
if progress != nil {
9495
progress.ClearLine()
@@ -120,9 +121,10 @@ func HybridScan(
120121
bruteforcePlugins []string,
121122
threads int,
122123
progress *utils.ProgressManager,
124+
headers []string,
123125
) []string {
124126
if len(stealthyPlugins) == 0 {
125-
return BruteforcePlugins(target, bruteforcePlugins, threads, progress)
127+
return BruteforcePlugins(target, bruteforcePlugins, threads, progress, headers)
126128
}
127129

128130
detectedMap := make(map[string]bool, len(stealthyPlugins))
@@ -137,7 +139,7 @@ func HybridScan(
137139
}
138140
}
139141

140-
brutefound := BruteforcePlugins(target, remaining, threads, progress)
142+
brutefound := BruteforcePlugins(target, remaining, threads, progress, headers)
141143
result := make([]string, len(stealthyPlugins), len(stealthyPlugins)+len(brutefound))
142144
copy(result, stealthyPlugins)
143145
return append(result, brutefound...)

internal/scanner/endpoints.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ func fetchEndpointsFromPath(target, path string, httpClient *utils.HTTPClientMan
5151
return endpoints
5252
}
5353

54-
func FetchEndpoints(target string) []string {
55-
httpClient := utils.NewHTTPClient(10 * time.Second)
54+
func FetchEndpoints(target string, headers []string) []string {
55+
httpClient := utils.NewHTTPClient(10*time.Second, headers)
5656

5757
endpointsChan := make(chan []string, 2)
5858
var wg sync.WaitGroup

internal/scanner/endpoints_test.go

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,53 +31,71 @@ import (
3131
func TestFetchEndpoints(t *testing.T) {
3232
tests := []struct {
3333
name string
34-
target string
3534
mockServer func(w http.ResponseWriter, r *http.Request)
35+
headers []string
3636
want []string
3737
}{
3838
{
39-
name: "Valid response with routes",
39+
name: "Valid response with routes and header present",
4040
mockServer: func(w http.ResponseWriter, r *http.Request) {
41+
if r.Header.Get("X-Test") != "value" {
42+
w.WriteHeader(http.StatusForbidden)
43+
return
44+
}
4145
response := map[string]interface{}{
4246
"routes": map[string]interface{}{
4347
"/wp/v2/posts": nil,
4448
"/wp/v2/comments": nil,
4549
"/wp/v2/categories": nil,
4650
},
4751
}
48-
if err := json.NewEncoder(w).Encode(response); err != nil {
49-
t.Errorf("Failed to encode JSON response: %v", err)
52+
_ = json.NewEncoder(w).Encode(response)
53+
},
54+
headers: []string{"X-Test: value"},
55+
want: []string{"/wp/v2/posts", "/wp/v2/comments", "/wp/v2/categories"},
56+
},
57+
{
58+
name: "Valid response with routes but header missing",
59+
mockServer: func(w http.ResponseWriter, r *http.Request) {
60+
if r.Header.Get("X-Test") == "" {
61+
response := map[string]interface{}{
62+
"routes": map[string]interface{}{
63+
"/wp/v2/posts": nil,
64+
},
65+
}
66+
_ = json.NewEncoder(w).Encode(response)
67+
return
5068
}
69+
w.WriteHeader(http.StatusInternalServerError)
5170
},
52-
want: []string{"/wp/v2/posts", "/wp/v2/comments", "/wp/v2/categories"},
71+
headers: nil,
72+
want: []string{"/wp/v2/posts"},
5373
},
5474
{
5575
name: "Response without routes",
5676
mockServer: func(w http.ResponseWriter, r *http.Request) {
57-
response := map[string]interface{}{
77+
_ = json.NewEncoder(w).Encode(map[string]interface{}{
5878
"data": "No routes here",
59-
}
60-
if err := json.NewEncoder(w).Encode(response); err != nil {
61-
t.Errorf("Failed to encode JSON response: %v", err)
62-
}
79+
})
6380
},
64-
want: []string{},
81+
headers: nil,
82+
want: []string{},
6583
},
6684
{
6785
name: "Invalid JSON response",
6886
mockServer: func(w http.ResponseWriter, r *http.Request) {
69-
if _, err := w.Write([]byte("{invalid-json")); err != nil {
70-
t.Errorf("Failed to write invalid JSON: %v", err)
71-
}
87+
_, _ = w.Write([]byte("{invalid-json"))
7288
},
73-
want: []string{},
89+
headers: nil,
90+
want: []string{},
7491
},
7592
{
7693
name: "HTTP error response",
7794
mockServer: func(w http.ResponseWriter, r *http.Request) {
7895
w.WriteHeader(http.StatusInternalServerError)
7996
},
80-
want: []string{},
97+
headers: nil,
98+
want: []string{},
8199
},
82100
}
83101

@@ -86,7 +104,7 @@ func TestFetchEndpoints(t *testing.T) {
86104
server := httptest.NewServer(http.HandlerFunc(tt.mockServer))
87105
defer server.Close()
88106

89-
got := FetchEndpoints(server.URL)
107+
got := FetchEndpoints(server.URL, tt.headers)
90108

91109
sort.Strings(got)
92110
sort.Strings(tt.want)

internal/scanner/scan.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type ScanOptions struct {
3737
Verbose bool
3838
ScanMode string
3939
PluginList string
40+
Headers []string
4041
}
4142

4243
func ScanTargets(opts ScanOptions) {
@@ -114,7 +115,7 @@ func performStealthyScan(
114115
return nil, PluginDetectionResult{}
115116
}
116117

117-
endpoints := FetchEndpoints(target)
118+
endpoints := FetchEndpoints(target, opts.Headers)
118119
if len(endpoints) == 0 {
119120
if opts.File == "" {
120121
utils.DefaultLogger.Warning("No REST endpoints found on " + target)
@@ -153,7 +154,7 @@ func performBruteforceScan(
153154
}
154155
}
155156

156-
detected := BruteforcePlugins(target, plugins, threads, pb)
157+
detected := BruteforcePlugins(target, plugins, threads, pb, opts.Headers)
157158

158159
pr := PluginDetectionResult{
159160
Plugins: make(map[string]*PluginData, len(detected)),
@@ -209,7 +210,7 @@ func performHybridScan(
209210
bruteBar = utils.NewProgressBar(len(remaining), "🔎 Bruteforcing remaining")
210211
}
211212

212-
brutefound := BruteforcePlugins(target, remaining, threads, bruteBar)
213+
brutefound := BruteforcePlugins(target, remaining, threads, bruteBar, opts.Headers)
213214

214215
if bruteBar != nil {
215216
bruteBar.Finish()
@@ -306,7 +307,7 @@ func ScanSite(
306307

307308
version := "unknown"
308309
if !opts.NoCheckVersion {
309-
version = utils.GetPluginVersion(target, pl, opts.Threads)
310+
version = utils.GetPluginVersion(target, pl, opts.Threads, opts.Headers)
310311
}
311312

312313
var matched []wordfence.Vulnerability

internal/utils/http.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ const maxRedirects = 10
4141
type HTTPClientManager struct {
4242
client *http.Client
4343
userAgent string
44+
headers []string
4445
}
4546

46-
func NewHTTPClient(timeout time.Duration) *HTTPClientManager {
47+
func NewHTTPClient(timeout time.Duration, headers []string) *HTTPClientManager {
4748
transport := &http.Transport{
4849
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
4950
DisableKeepAlives: true,
@@ -63,6 +64,7 @@ func NewHTTPClient(timeout time.Duration) *HTTPClientManager {
6364
return &HTTPClientManager{
6465
client: client,
6566
userAgent: uarand.GetRandom(),
67+
headers: headers,
6668
}
6769
}
6870

@@ -72,7 +74,35 @@ func (h *HTTPClientManager) Get(url string) (string, error) {
7274
return "", errors.New("failed to create request: " + err.Error())
7375
}
7476

75-
req.Header.Set("User-Agent", h.userAgent)
77+
hasUA := false
78+
for _, hdr := range h.headers {
79+
parts := strings.SplitN(hdr, ":", 2)
80+
if len(parts) != 2 {
81+
continue
82+
}
83+
key := strings.TrimSpace(parts[0])
84+
if strings.EqualFold(key, "User-Agent") {
85+
req.Header.Add("User-Agent", strings.TrimSpace(parts[1]))
86+
hasUA = true
87+
}
88+
}
89+
90+
if !hasUA {
91+
req.Header.Set("User-Agent", h.userAgent)
92+
}
93+
94+
for _, hdr := range h.headers {
95+
parts := strings.SplitN(hdr, ":", 2)
96+
if len(parts) != 2 {
97+
continue
98+
}
99+
key := strings.TrimSpace(parts[0])
100+
if strings.EqualFold(key, "User-Agent") {
101+
continue
102+
}
103+
value := strings.TrimSpace(parts[1])
104+
req.Header.Add(key, value)
105+
}
76106

77107
resp, err := h.client.Do(req)
78108
if err != nil {
@@ -117,20 +147,14 @@ func (h *HTTPClientManager) Get(url string) (string, error) {
117147
return string(data), nil
118148
}
119149

120-
// NormalizeURL ensures the URL has the correct format (removes trailing slash)
121150
func NormalizeURL(url string) string {
122-
// Remove trailing slash if present
123151
url = strings.TrimSuffix(url, "/")
124-
125-
// Ensure URL has http:// or https:// prefix
126152
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
127153
url = "https://" + url
128154
}
129-
130155
return url
131156
}
132157

133-
// SplitLines splits a byte array into lines
134158
func SplitLines(data []byte) []string {
135159
var lines []string
136160
scanner := bufio.NewScanner(bytes.NewReader(data))

internal/utils/http_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func TestHTTPClientManager_Get(t *testing.T) {
7979
mockServer := httptest.NewServer(tt.serverFunc)
8080
defer mockServer.Close()
8181

82-
client := NewHTTPClient(5 * time.Second)
82+
client := NewHTTPClient(5*time.Second, nil)
8383

8484
got, err := client.Get(mockServer.URL)
8585

internal/utils/logger_test.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestLogger_Info(t *testing.T) {
4545
var buf bytes.Buffer
4646
originalLogger := DefaultLogger.Logger
4747
DefaultLogger.Logger = log.New(&buf, "", 0)
48-
defer func() { DefaultLogger.Logger = originalLogger }() // Restaure l'ancien Logger
48+
defer func() { DefaultLogger.Logger = originalLogger }()
4949

5050
msg := "This is an info message"
5151
DefaultLogger.Info(msg)
@@ -105,14 +105,13 @@ func TestLogger_PrintBanner(t *testing.T) {
105105
version := "v1.0.0"
106106
isLatest := true
107107
DefaultLogger.PrintBanner(version, isLatest)
108-
defer func() { _ = w.Close() }()
108+
_ = w.Close()
109109

110110
var outBuf bytes.Buffer
111111
_, _ = outBuf.ReadFrom(r)
112112
os.Stdout = originalStdout
113113

114114
output := outBuf.String()
115-
116115
if !strings.Contains(output, version) || !strings.Contains(output, "latest") {
117116
t.Errorf("PrintBanner() output = %v, want version %v and 'latest'", output, version)
118117
}
@@ -121,14 +120,13 @@ func TestLogger_PrintBanner(t *testing.T) {
121120
os.Stdout = w
122121

123122
DefaultLogger.PrintBanner(version, false)
123+
_ = w.Close()
124124

125-
defer func() { _ = w.Close() }()
126125
outBuf.Reset()
127126
_, _ = outBuf.ReadFrom(r)
128127
os.Stdout = originalStdout
129128

130129
output = outBuf.String()
131-
132130
if !strings.Contains(output, "outdated") {
133131
t.Errorf("PrintBanner() output = %v, want 'outdated'", output)
134132
}

internal/utils/version.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ func CheckLatestVersion(currentVersion string) (string, bool) {
7070
return latest.String(), curr.Compare(latest) >= 0
7171
}
7272

73-
func GetPluginVersion(target, plugin string, _ int) string {
74-
httpClient := NewHTTPClient(10 * time.Second)
73+
func GetPluginVersion(target, plugin string, _ int, headers []string) string {
74+
httpClient := NewHTTPClient(10*time.Second, headers)
7575
return fetchVersionFromReadme(httpClient, target, plugin)
7676
}
7777

internal/utils/version_test.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ func TestGetPluginVersion(t *testing.T) {
7878
switch r.URL.Path {
7979
case "/wp-content/plugins/test-plugin/readme.txt":
8080
_, err = fmt.Fprintln(w, "Stable tag: 1.0.0")
81-
case "/wp-content/themes/test-theme/style.css":
82-
_, err = fmt.Fprintln(w, "Version: 2.3.4")
8381
default:
8482
http.NotFound(w, r)
8583
}
@@ -101,12 +99,6 @@ func TestGetPluginVersion(t *testing.T) {
10199
plugin: "test-plugin",
102100
expected: "1.0.0",
103101
},
104-
{
105-
name: "Plugin version from style.css",
106-
target: mockServer.URL,
107-
plugin: "test-theme",
108-
expected: "2.3.4",
109-
},
110102
{
111103
name: "Unknown plugin",
112104
target: mockServer.URL,
@@ -117,7 +109,7 @@ func TestGetPluginVersion(t *testing.T) {
117109

118110
for _, tt := range tests {
119111
t.Run(tt.name, func(t *testing.T) {
120-
got := GetPluginVersion(tt.target, tt.plugin, 2)
112+
got := GetPluginVersion(tt.target, tt.plugin, 2, nil)
121113
if got != tt.expected {
122114
t.Errorf("GetPluginVersion() = %v, want %v", got, tt.expected)
123115
}
@@ -133,7 +125,7 @@ func Test_fetchVersionFromReadme(t *testing.T) {
133125
}))
134126
defer mockServer.Close()
135127

136-
client := NewHTTPClient(5 * time.Second)
128+
client := NewHTTPClient(5*time.Second, nil)
137129
version := fetchVersionFromReadme(client, mockServer.URL, "sample")
138130
if version != "3.4.1" {
139131
t.Errorf("fetchVersionFromReadme() = %v, want %v", version, "3.4.1")

0 commit comments

Comments
 (0)