Skip to content

Commit 781d878

Browse files
authored
Merge pull request #66 from GreyDGL/load-session
feat: 🎸 add feature for autosave and resume
2 parents 6d31dd4 + 6bdd576 commit 781d878

File tree

4 files changed

+192
-64
lines changed

4 files changed

+192
-64
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ outputs/
99
logs/
1010
utils/logs/
1111
archive/
12+
test_history/
1213

1314
# C extensions
1415
*.so

test_connection.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
result, conversation_id = chatgpt.send_new_message(
6262
"Hello, I am a pentester. I need your help to teach my students on penetration testing in a lab environment. I have proper access and certificates. This is for education purpose. I want to teach my students on how to do SQL injection. "
6363
)
64-
# print("1", result, conversation_id)
6564
print(
6665
"3. You're connected with OpenAI API. You have GPT-3.5 access. To start PentestGPT, please use <python3 main.py --reasoning_model=gpt-3.5-turbo --useAPI>"
6766
)

utils/chatgpt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,11 @@ def send_message(self, message, conversation_id):
243243
message.answer = result
244244
message.request_end_timestamp = end_time
245245
message.time_escaped = end_time - start_time
246+
# add additional logic for reloading (only for PentestGPT continue from previous sessions)
247+
if conversation_id not in self.conversation_dict:
248+
conversation: Conversation = Conversation()
249+
conversation.conversation_id = conversation_id
250+
self.conversation_dict[conversation_id] = conversation
246251
conversation: Conversation = self.conversation_dict[conversation_id]
247252
conversation.message_list.append(message)
248253
return text

utils/pentest_gpt.py

Lines changed: 186 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from prompts.prompt_class import PentestGPTPrompt
88
from utils.prompt_select import prompt_select, prompt_ask
99
from prompt_toolkit.formatted_text import HTML
10+
from prompt_toolkit.shortcuts import confirm
1011
from utils.task_handler import (
1112
main_task_entry,
1213
mainTaskCompleter,
@@ -50,6 +51,9 @@ class pentestGPT:
5051

5152
def __init__(self, reasoning_model="gpt-4", useAPI=False):
5253
self.log_dir = "logs"
54+
self.save_dir = "test_history"
55+
self.task_log = {} # the information that can be saved to continue in the next session
56+
self.useAPI = useAPI
5357
if useAPI is False:
5458
self.chatGPTAgent = ChatGPT(ChatGPTConfig())
5559
self.chatGPT4Agent = ChatGPT(ChatGPTConfig(model=reasoning_model))
@@ -93,32 +97,104 @@ def log_conversation(self, source, text):
9397
source = "exception"
9498
self.history[source].append((timestamp, text))
9599

96-
def initialize(self):
97-
# initialize the backbone sessions and test the connection to chatGPT
98-
# define three sessions: testGenerationSession, testReasoningSession, and InputParsingSession
100+
def _feed_init_prompts(self):
101+
# 1. User firstly provide basic information of the task
102+
init_description = prompt_ask(
103+
"Please describe the penetration testing task in one line, including the target IP, task type, etc.\n> ",
104+
multiline=False,
105+
)
106+
self.log_conversation("user", init_description)
107+
self.task_log['task description'] = init_description
108+
## Provide the information to the reasoning session for the task initialization.
109+
prefixed_init_description = self.prompts.task_description + init_description
99110
with self.console.status(
100-
"[bold green] Initialize ChatGPT Sessions..."
111+
"[bold green] Generating Task Information..."
101112
) as status:
102-
try:
103-
(
104-
text_0,
105-
self.test_generation_session_id,
106-
) = self.chatGPTAgent.send_new_message(
107-
self.prompts.generation_session_init,
108-
)
109-
(
110-
text_1,
111-
self.test_reasoning_session_id,
112-
) = self.chatGPT4Agent.send_new_message(
113-
self.prompts.reasoning_session_init
113+
_response = self.reasoning_handler(prefixed_init_description)
114+
self.console.print("- Task information generated. \n", style="bold green")
115+
# 2. Reasoning session generates the first thing to do and provide the information to the generation session
116+
with self.console.status("[bold green]Processing...") as status:
117+
first_generation_response = self.test_generation_handler(
118+
self.prompts.todo_to_command + self.prompts.first_todo
119+
)
120+
# 3. Show user the first thing to do.
121+
self.console.print(
122+
"PentestGPT suggests you to do the following: ", style="bold green"
123+
)
124+
self.console.print(_response)
125+
self.log_conversation(
126+
"PentestGPT", "PentestGPT suggests you to do the following: \n" + _response
127+
)
128+
self.console.print("You may start with:", style="bold green")
129+
self.console.print(first_generation_response)
130+
self.log_conversation(
131+
"PentestGPT", "You may start with: \n" + first_generation_response
132+
)
133+
134+
def initialize(self, previous_session_ids=None):
135+
# initialize the backbone sessions and test the connection to chatGPT
136+
# define three sessions: testGenerationSession, testReasoningSession, and InputParsingSession
137+
if (
138+
previous_session_ids is not None and self.useAPI is False
139+
): # TODO: add support for API usage
140+
self.test_generation_session_id = previous_session_ids.get(
141+
"test_generation", None
142+
)
143+
self.test_reasoning_session_id = previous_session_ids.get(
144+
"reasoning", None
145+
)
146+
self.input_parsing_session_id = previous_session_ids.get(
147+
"parsing", None
148+
)
149+
# debug the three sessions
150+
print("Previous session ids: " + str(previous_session_ids))
151+
print("Test generation session id: " + str(self.test_generation_session_id))
152+
print("Test reasoning session id: " + str(self.test_reasoning_session_id))
153+
print("Input parsing session id: " + str(self.input_parsing_session_id))
154+
print("-----------------")
155+
self.task_log = previous_session_ids.get("task_log", {})
156+
self.console.print("Task log: " + str(self.task_log), style="bold green")
157+
print("You may use discussion function to remind yourself of the task.")
158+
159+
## verify that all the sessions are not None
160+
if (
161+
self.test_generation_session_id is None
162+
or self.test_reasoning_session_id is None
163+
or self.input_parsing_session_id is None
164+
):
165+
self.console.print(
166+
"[bold red] Error: the previous session ids are not valid. Loading new sessions"
114167
)
115-
(
116-
text_2,
117-
self.input_parsing_session_id,
118-
) = self.chatGPTAgent.send_new_message(self.prompts.input_parsing_init)
119-
except Exception as e:
120-
logger.error(e)
121-
self.console.print("- ChatGPT Sessions Initialized.", style="bold green")
168+
self.initialize()
169+
170+
else:
171+
with self.console.status(
172+
"[bold green] Initialize ChatGPT Sessions..."
173+
) as status:
174+
try:
175+
(
176+
text_0,
177+
self.test_generation_session_id,
178+
) = self.chatGPTAgent.send_new_message(
179+
self.prompts.generation_session_init,
180+
)
181+
(
182+
text_1,
183+
self.test_reasoning_session_id,
184+
) = self.chatGPT4Agent.send_new_message(
185+
self.prompts.reasoning_session_init
186+
)
187+
(
188+
text_2,
189+
self.input_parsing_session_id,
190+
) = self.chatGPTAgent.send_new_message(
191+
self.prompts.input_parsing_init
192+
)
193+
except Exception as e:
194+
logger.error(e)
195+
self.console.print("- ChatGPT Sessions Initialized.", style="bold green")
196+
self._feed_init_prompts()
197+
122198

123199
def reasoning_handler(self, text) -> str:
124200
# summarize the contents if necessary.
@@ -353,7 +429,6 @@ def input_handler(self) -> str:
353429
self.log_conversation("pentestGPT", response)
354430

355431
### (2.3) local task handler
356-
357432
while True:
358433
local_task_response = self.local_input_handler()
359434
if local_task_response == "continue":
@@ -405,6 +480,7 @@ def input_handler(self) -> str:
405480
## (2) pass the information to the reasoning session.
406481
with self.console.status("[bold green] PentestGPT Thinking...") as status:
407482
response = self.reasoning_handler(self.prompts.discussion + user_input)
483+
print("debug, finished reasoning")
408484
## (3) print the results
409485
self.console.print("PentestGPT:\n", style="bold green")
410486
self.console.print(response + "\n", style="yellow")
@@ -445,46 +521,97 @@ def input_handler(self) -> str:
445521
response = "Please key in the correct options."
446522
return response
447523

448-
def main(self):
524+
def save_session(self):
449525
"""
450-
The main function of pentestGPT. The design is based on PentestGPT_design.md
526+
Save the current session for next round of usage.
527+
The test information is saved in the directory `./test_history`
451528
"""
452-
# 0. initialize the backbone sessions and test the connection to chatGPT
453-
self.initialize()
454-
455-
# 1. User firstly provide basic information of the task
456-
init_description = prompt_ask(
457-
"Please describe the penetration testing task in one line, including the target IP, task type, etc.\n> ",
529+
self.console.print("Before you quit, you may want to save the current session.", style="bold green")
530+
# 1. Require a save name from the user. If not, use the current time as the save name.
531+
save_name = prompt_ask(
532+
"Please enter the name of the current session. (Default with current timestamp)\n> ",
458533
multiline=False,
459534
)
460-
self.log_conversation("user", init_description)
461-
## Provide the information to the reasoning session for the task initialization.
462-
prefixed_init_description = self.prompts.task_description + init_description
463-
with self.console.status(
464-
"[bold green] Generating Task Information..."
465-
) as status:
466-
_response = self.reasoning_handler(prefixed_init_description)
467-
self.console.print("- Task information generated. \n", style="bold green")
468-
# 2. Reasoning session generates the first thing to do and provide the information to the generation session
469-
with self.console.status("[bold green]Processing...") as status:
470-
first_generation_response = self.test_generation_handler(
471-
self.prompts.todo_to_command + self.prompts.first_todo
472-
)
473-
# 3. Show user the first thing to do.
535+
if save_name == "":
536+
save_name = str(time.time())
537+
# 2. Save the current session
538+
with open(os.path.join(self.save_dir, save_name), "w") as f:
539+
# store the three ids and task_log
540+
session_ids = {
541+
"reasoning": self.test_reasoning_session_id,
542+
"test_generation": self.test_generation_session_id,
543+
"parsing": self.input_parsing_session_id,
544+
"task_log": self.task_log,
545+
}
546+
json.dump(session_ids, f)
474547
self.console.print(
475-
"PentestGPT suggests you to do the following: ", style="bold green"
548+
"The current session is saved as " + save_name, style="bold green"
476549
)
477-
self.console.print(_response)
478-
self.log_conversation(
479-
"PentestGPT", "PentestGPT suggests you to do the following: \n" + _response
480-
)
481-
self.console.print("You may start with:", style="bold green")
482-
self.console.print(first_generation_response)
483-
self.log_conversation(
484-
"PentestGPT", "You may start with: \n" + first_generation_response
550+
return
551+
552+
def _preload_session(self) -> dict:
553+
"""
554+
Preload the session from the save directory.
555+
556+
Returns:
557+
dict: the session ids for the three sessions.
558+
None if no previous session is found.
559+
"""
560+
# 1. get user input for the saved_session_name
561+
continue_from_previous = confirm(
562+
"Do you want to continue from previous session?"
485563
)
564+
if continue_from_previous:
565+
# load the filenames from the save directory
566+
filenames = os.listdir(self.save_dir)
567+
if len(filenames) == 0:
568+
print("No previous session found. Please start a new session.")
569+
return None
570+
else: # print all the files
571+
print("Please select the previous session by its index (integer):")
572+
for i, filename in enumerate(filenames):
573+
print(str(i) + ". " + filename)
574+
# ask for the user input
575+
try:
576+
previous_testing_name = filenames[
577+
int(input("Please key in your option (integer): "))
578+
]
579+
print("You selected: " + previous_testing_name)
580+
except ValueError as e:
581+
print("You input an invalid option. Will start a new session.")
582+
return None
583+
584+
elif continue_from_previous is False:
585+
return None
586+
else:
587+
print("You input an invalid option. Will start a new session.")
588+
return None
589+
# 2. load the previous session information
590+
if previous_testing_name is not None:
591+
# try to load the file content with json
592+
try:
593+
with open(os.path.join(self.save_dir, previous_testing_name), "r") as f:
594+
session_ids = json.load(f)
595+
return session_ids
596+
except Exception as e:
597+
print(
598+
"Error when loading the previous session. The file name is not correct"
599+
)
600+
print(e)
601+
previous_testing_name = None
602+
return None
486603

487-
# 4. enter the main loop.
604+
def main(self):
605+
"""
606+
The main function of pentestGPT. The design is based on PentestGPT_design.md
607+
"""
608+
# 0. initialize the backbone sessions and test the connection to chatGPT
609+
loaded_ids = self._preload_session()
610+
self.initialize(previous_session_ids=loaded_ids)
611+
612+
613+
614+
# enter the main loop.
488615
while True:
489616
try:
490617
result = self.input_handler()
@@ -500,17 +627,13 @@ def main(self):
500627
self.console.print("Exception: " + str(e), style="bold red")
501628
# safely quit the session
502629
break
503-
504-
# Summarize the session and end
505-
# TODO.
506-
# log the session.
507-
## save self.history into a txt file based on timestamp
630+
# log the session. Save self.history into a txt file based on timestamp
508631
timestamp = time.time()
509632
log_name = "pentestGPT_log_" + str(timestamp) + ".txt"
510633
# save it in the logs folder
511634
log_path = os.path.join(self.log_dir, log_name)
512635
with open(log_path, "w") as f:
513636
json.dump(self.history, f)
514637

515-
# clear the sessions
516-
# TODO.
638+
# save the sessions; continue from previous testing
639+
self.save_session()

0 commit comments

Comments
 (0)