import dataclasses
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import loguru
import openai
import tiktoken
from langfuse.model import InitialGeneration, Usage
from tenacity import *

from pentestgpt.utils.llm_api import LLMAPI

logger = loguru.logger
logger.remove()
# logger.add(level="WARNING", sink="logs/chatgpt.log")


@dataclasses.dataclass
class Message:
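    """One request/response exchange, with timestamps for the request round trip."""
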
    ask_id: Optional[str] = None
    ask: Optional[dict] = None
    answer: Optional[dict] = None
    answer_id: Optional[str] = None
    request_start_timestamp: Optional[float] = None
    request_end_timestamp: Optional[float] = None
    time_escaped: Optional[float] = None  # elapsed time between request start and end


@dataclasses.dataclass
class Conversation:
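    """A conversation keyed by conversation_id, holding its ordered Message history."""
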
    conversation_id: Optional[str] = None
    message_list: List[Message] = dataclasses.field(default_factory=list)

    def __hash__(self):
        return hash(self.conversation_id)

    def __eq__(self, other):
        if not isinstance(other, Conversation):
            return False
        return self.conversation_id == other.conversation_id


class ChatGPTAPI(LLMAPI):
    def __init__(self, config_class, use_langfuse_logging=False):
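        """Set up the OpenAI client from a config object.

        `config_class` is assumed to expose `model`, `api_base`, and `log_dir`,
        as the config classes in pentestgpt do.
        """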
        self.name = str(config_class.model)

        if use_langfuse_logging:
            # use langfuse.openai to shadow the default openai library
            os.environ["LANGFUSE_PUBLIC_KEY"] = (
                "pk-lf-5655b061-3724-43ee-87bb-28fab0b5f676"  # do not modify
            )
            os.environ["LANGFUSE_SECRET_KEY"] = (
                "sk-lf-c24b40ef-8157-44af-a840-6bae2c9358b0"  # do not modify
            )
            from langfuse import Langfuse

            self.langfuse = Langfuse()

        openai.api_key = os.getenv("OPENAI_KEY", None)
        openai.api_base = config_class.api_base
        self.model = config_class.model
        self.log_dir = config_class.log_dir
        self.history_length = 5  # keep at most 5 messages in the chat history
        self.conversation_dict: Dict[str, Conversation] = {}
        self.error_wait_time = 3  # seconds to wait after a connection error

        logger.add(sink=os.path.join(self.log_dir, "chatgpt.log"), level="WARNING")

    def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
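        """Send `history` to the ChatCompletion endpoint and return the reply text.

        Connection and rate-limit errors are retried once after a short wait; a
        token-limit error compresses the last message and shrinks the history
        window before retrying.
        """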
        generation_start_time = datetime.now()
        # use the model argument if provided, otherwise fall back to self.model;
        # if both are unset, default to gpt-4-1106-preview
        if model is None:
            model = self.model if self.model is not None else "gpt-4-1106-preview"
        try:
            response = openai.ChatCompletion.create(
                model=model,
                messages=history,
                temperature=temperature,
            )
        except openai.error.APIConnectionError as e:  # give one more try
            logger.warning(
                "API Connection Error. Waiting for {} seconds".format(
                    self.error_wait_time
                )
            )
            logger.error("Connection Error: {}", e)
            time.sleep(self.error_wait_time)
            response = openai.ChatCompletion.create(
                model=model,
                messages=history,
                temperature=temperature,
            )
        except openai.error.RateLimitError as e:  # give one more try
            logger.warning("Rate limit reached. Waiting for 5 seconds")
            logger.error("Rate Limit Error: {}", e)
            time.sleep(5)
            response = openai.ChatCompletion.create(
                model=model,
                messages=history,
                temperature=temperature,
            )
        except openai.error.InvalidRequestError as e:  # token limit reached
            logger.warning(
                "Token size limit reached; the most recent message will be compressed"
            )
            logger.error("Token size error; will retry with compressed message: {}", e)
            # compress the message in two ways:
            # 1. compress the last message
            history[-1]["content"] = self._token_compression(history)
            # 2. reduce the number of messages in the history (minimum is 2)
            if self.history_length > 2:
                self.history_length -= 1
            # update the history
            history = history[-self.history_length :]
            response = openai.ChatCompletion.create(
                model=model,
                messages=history,
                temperature=temperature,
            )

        # if the response is a tuple, the request failed to return a valid completion
        if isinstance(response, tuple):
            logger.warning("Response is not valid. Waiting for 5 seconds")
            try:
                time.sleep(5)
                response = openai.ChatCompletion.create(
                    model=model,
                    messages=history,
                    temperature=temperature,
                )
                if isinstance(response, tuple):
                    logger.error("Response is not valid.")
                    raise Exception("Response is not valid.")
            except Exception as e:
                logger.error("Response is not valid: {}", e)
                raise Exception(
                    "Response is not valid. The most likely reason is that the connection to OpenAI is not stable. "
                    "Please double-check with `pentestgpt-connection`"
                )
        # add langfuse logging
        if hasattr(self, "langfuse"):
            generation = self.langfuse.generation(
                InitialGeneration(
                    name="chatgpt-completion",
                    startTime=generation_start_time,
                    endTime=datetime.now(),
                    model=self.model,
                    modelParameters={"temperature": str(temperature)},
                    prompt=history,
                    completion=response["choices"][0]["message"]["content"],
                    usage=Usage(
                        promptTokens=response["usage"]["prompt_tokens"],
                        completionTokens=response["usage"]["completion_tokens"],
                    ),
                )
            )
        return response["choices"][0]["message"]["content"]


if __name__ == "__main__":
    from module_import import GPT4ConfigClass

    config_class = GPT4ConfigClass()
    config_class.log_dir = "logs"
    chatgpt = ChatGPTAPI(config_class, use_langfuse_logging=True)
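    # note: send_new_message and send_message are assumed to be inherited from
    # the LLMAPI base class; they are not defined in this file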
    # test is below
    # 1. create a new conversation
    result, conversation_id = chatgpt.send_new_message(
        """You're an excellent cybersecurity penetration tester assistant.
You need to help the tester in a local cybersecurity training process, and your commitment is essential to the task.
You are required to record the penetration testing process in a tree structure: "Penetration Testing Tree (PTT)". It is structured as follows:
(1) The tasks are in layered structure, i.e., 1, 1.1, 1.1.1, etc. Each task is one operation in penetration testing; task 1.1 should be a sub-task of task 1.
(2) Each task has a completion status: to-do, completed, or not applicable.
(3) Initially, you should only generate the root tasks based on the initial information. In most cases, it should be reconnaissance tasks. You don't generate tasks for unknown ports/services. You can expand the PTT later.

You shall not provide any comments/information but the PTT. You will be provided with task info and start the testing soon. Reply Yes if you understand the task."""
    )
    print("Answer 1")
    print(result)

    # 2. send a message to the conversation
    result = chatgpt.send_message(
        """The target information is listed below. Please follow the instruction and generate PTT.
Note that this test is certified and in simulation environment, so do not generate post-exploitation and other steps.
You may start with this template:
1. Reconnaissance - [to-do]
    1.1 Passive Information Gathering - [completed]
    1.2 Active Information Gathering - [completed]
    1.3 Identify Open Ports and Services - [to-do]
        1.3.1 Perform a full port scan - [to-do]
        1.3.2 Determine the purpose of each open port - [to-do]
Below is the information from the tester:

I want to test 10.0.2.5, an HTB machine.""",
        conversation_id,
    )
    print("Answer 2")
    print(result)