diff --git a/graphrag_sdk/chat_session.py b/graphrag_sdk/chat_session.py
index cda7ed2..c0c7098 100644
--- a/graphrag_sdk/chat_session.py
+++ b/graphrag_sdk/chat_session.py
@@ -3,7 +3,6 @@
 from typing import Iterator
 from graphrag_sdk.ontology import Ontology
 from graphrag_sdk.steps.qa_step import QAStep
-from graphrag_sdk.steps.stream_qa_step import StreamingQAStep
 from graphrag_sdk.model_config import KnowledgeGraphModelConfig
 from graphrag_sdk.steps.graph_query_step import GraphQueryGenerationStep
 
@@ -24,8 +23,19 @@ class ChatSession:
     >>> from graphrag_sdk.model_config import KnowledgeGraphModelConfig
     >>> model_config = KnowledgeGraphModelConfig.with_model(model)
     >>> kg = KnowledgeGraph("test_kg", model_config, ontology)
-    >>> chat_session = kg.start_chat()
-    >>> chat_session.send_message("What is the capital of France?")
+    >>> session = kg.chat_session()
+    >>>
+    >>> # Full QA pipeline
+    >>> response = session.send_message("What is the capital of France?")
+    >>> print(response["question"])  # "What is the capital of France?"
+    >>> print(response["response"])  # "Paris"
+    >>> print(response["context"])   # Retrieved context
+    >>> print(response["cypher"])    # "MATCH (c:City)..."
+    >>>
+    >>> # Just generate a Cypher query, without running QA
+    >>> context, cypher = session.generate_cypher_query("Who are the actors in The Matrix?")
+    >>> print(context)  # Retrieved context
+    >>> print(cypher)   # "MATCH (a:Actor)-[:ACTED_IN]->(m:Movie {title: 'The Matrix'}) RETURN a.name"
     """
 
     def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, graph: Graph,
@@ -38,20 +48,25 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
             model_config (KnowledgeGraphModelConfig): The model configuration.
             ontology (Ontology): The ontology object.
             graph (Graph): The graph object.
+            cypher_system_instruction (str): System instruction for Cypher generation.
+            qa_system_instruction (str): System instruction for QA.
+            cypher_gen_prompt (str): Prompt template for Cypher generation.
+            qa_prompt (str): Prompt template for QA.
+            cypher_gen_prompt_history (str): Prompt template for Cypher generation with chat history.
 
         Attributes:
             model_config (KnowledgeGraphModelConfig): The model configuration.
             ontology (Ontology): The ontology object.
            graph (Graph): The graph object.
-            cypher_chat_session (CypherChatSession): The Cypher chat session object.
-            qa_chat_session (QAChatSession): The QA chat session object.
+            cypher_chat_session: The Cypher chat session object.
+            qa_chat_session: The QA chat session object.
         """
         self.model_config = model_config
         self.graph = graph
         self.ontology = ontology
 
         # Filter the ontology to remove unique and required attributes that are not needed for Q&A.
-        ontology_prompt = self.clean_ontology_for_prompt(ontology)
+        ontology_prompt = clean_ontology_for_prompt(ontology)
 
         cypher_system_instruction = cypher_system_instruction.format(ontology=ontology_prompt)
 
@@ -65,6 +80,21 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
         self.qa_chat_session = model_config.qa.start_chat(
             qa_system_instruction
         )
+
+        # Initialize steps once during construction
+        self.qa_step = QAStep(
+            chat_session=self.qa_chat_session,
+            qa_prompt=self.qa_prompt,
+        )
+        self.cypher_step = GraphQueryGenerationStep(
+            graph=self.graph,
+            chat_session=self.cypher_chat_session,
+            ontology=self.ontology,
+            last_answer=None,  # Will be updated dynamically
+            cypher_prompt=self.cypher_prompt,
+            cypher_prompt_with_history=self.cypher_prompt_with_history
+        )
+
         self.last_complete_response = {
             "question": None,
             "response": None,
@@ -75,26 +105,37 @@
         # Metadata to store additional information about the chat session (currently only the last query execution time)
         self.metadata = {"last_query_execution_time": None}
 
-    def _generate_cypher_query(self, message: str) -> tuple:
+    def _update_last_complete_response(self, response_dict: dict):
+        """Update the last complete response in both the session and the Cypher step."""
+        self.last_complete_response = response_dict
+        self.cypher_step.last_answer = response_dict.get("response")
+
+    def generate_cypher_query(self, message: str) -> tuple:
         """
-        Generate a Cypher query for the given message.
+        Generate a Cypher query for the given message without running QA.
+
+        This method allows users to get just the Cypher query and context
+        without executing the full question-answering pipeline.
 
         Args:
             message (str): The message to generate a query for.
 
         Returns:
-            tuple: A tuple containing (context, cypher)
+            tuple: A tuple containing (context, cypher) where:
+                - context (str): The extracted context from the graph
+                - cypher (str): The generated Cypher query
         """
-        cypher_step = GraphQueryGenerationStep(
-            graph=self.graph,
-            chat_session=self.cypher_chat_session,
-            ontology=self.ontology,
-            last_answer=self.last_complete_response["response"],
-            cypher_prompt=self.cypher_prompt,
-            cypher_prompt_with_history=self.cypher_prompt_with_history
-        )
-
-        (context, cypher, query_execution_time) = cypher_step.run(message)
+        # Update the last_answer for this query
+        self.cypher_step.last_answer = self.last_complete_response.get("response")
+
+        try:
+            (context, cypher, query_execution_time) = self.cypher_step.run(message)
+        except Exception:
+            # On failure, return an empty context and cypher; callers translate
+            # the empty cypher into CYPHER_ERROR_RES.
+            context = None
+            cypher = None
+            query_execution_time = None
+
         self.metadata["last_query_execution_time"] = query_execution_time
 
         return (context, cypher)
 
@@ -107,42 +148,37 @@ def send_message(self, message: str) -> dict:
             message (str): The message to send.
         Returns:
-            dict: The response to the message in the following format:
-                {"question": message,
-                "response": answer,
-                "context": context,
-                "cypher": cypher}
+            dict: A dictionary containing the response with keys:
+                - "question": The original question
+                - "response": The answer
+                - "context": The extracted context from the graph
+                - "cypher": The generated Cypher query
         """
-        (context, cypher) = self._generate_cypher_query(message)
+        (context, cypher) = self.generate_cypher_query(message)
 
-        # If the cypher is empty, return an error message
+        # If the cypher is empty, return an error response
         if not cypher or len(cypher) == 0:
-            self.last_complete_response = {
+            return {
                 "question": message,
                 "response": CYPHER_ERROR_RES,
                 "context": None,
                 "cypher": None
             }
-            return self.last_complete_response
 
-        qa_step = QAStep(
-            chat_session=self.qa_chat_session,
-            qa_prompt=self.qa_prompt,
-        )
-
-        answer = qa_step.run(message, cypher, context)
-
-        self.last_complete_response = {
-            "question": message,
-            "response": answer,
-            "context": context,
+        answer = self.qa_step.run(message, cypher, context)
+
+        response = {
+            "question": message,
+            "response": answer,
+            "context": context,
             "cypher": cypher
         }
 
-        return self.last_complete_response
+        self._update_last_complete_response(response)
+
+        return response
 
     def send_message_stream(self, message: str) -> Iterator[str]:
-
         """
         Sends a message to the chat session and streams the response.
@@ -152,60 +188,52 @@ def send_message_stream(self, message: str) -> Iterator[str]:
         Yields:
             str: Chunks of the response as they're generated.
         """
-        (context, cypher) = self._generate_cypher_query(message)
+        (context, cypher) = self.generate_cypher_query(message)
 
         if not cypher or len(cypher) == 0:
             # Stream the error message for consistency with successful responses
             yield CYPHER_ERROR_RES
-
-            self.last_complete_response = {
-                "question": message,
-                "response": CYPHER_ERROR_RES,
-                "context": None,
-                "cypher": None
-            }
             return
 
-        qa_step = StreamingQAStep(
-            chat_session=self.qa_chat_session,
-            qa_prompt=self.qa_prompt,
-        )
-
         # Yield chunks of the response as they're generated
-        for chunk in qa_step.run(message, cypher, context):
+        for chunk in self.qa_step.run_stream(message, cypher, context):
            yield chunk
 
-        # Set the last answer using chat history to ensure we have the complete response
-        self.last_complete_response = {
-            "question": message,
-            "response": qa_step.chat_session.get_chat_history()[-1]['content'],
-            "context": context,
+        # Set the last answer using chat history to ensure we store the complete response
+        final_answer = self.qa_step.chat_session.get_chat_history()[-1]['content']
+
+        final_response = {
+            "question": message,
+            "response": final_answer,
+            "context": context,
             "cypher": cypher
         }
 
-    def clean_ontology_for_prompt(self, ontology: dict) -> str:
-        """
-        Cleans the ontology by removing 'unique' and 'required' keys and prepares it for use in a prompt.
+        self._update_last_complete_response(final_response)
 
-        Args:
-            ontology (dict): The ontology to clean and transform.
+def clean_ontology_for_prompt(ontology: Ontology) -> str:
+    """
+    Cleans the ontology by removing 'unique' and 'required' keys and prepares it for use in a prompt.
 
-        Returns:
-            str: The cleaned ontology as a JSON string.
-        """
-        # Convert the ontology object to a JSON.
-        ontology = ontology.to_json()
-
-        # Remove unique and required attributes from the ontology.
- for entity in ontology["entities"]: - for attribute in entity["attributes"]: - del attribute['unique'] - del attribute['required'] - - for relation in ontology["relations"]: - for attribute in relation["attributes"]: - del attribute['unique'] - del attribute['required'] - - # Return the transformed ontology as a JSON string - return json.dumps(ontology) \ No newline at end of file + Args: + ontology (Ontology): The ontology to clean and transform. + + Returns: + str: The cleaned ontology as a JSON string. + """ + # Convert the ontology object to a JSON. + ontology_json = ontology.to_json() + + # Remove unique and required attributes from the ontology. + for entity in ontology_json.get("entities", []): + for attribute in entity["attributes"]: + attribute.pop('unique', None) + attribute.pop('required', None) + + for relation in ontology_json.get("relations", []): + for attribute in relation["attributes"]: + attribute.pop('unique', None) + attribute.pop('required', None) + + # Return the transformed ontology as a JSON string + return json.dumps(ontology_json) \ No newline at end of file diff --git a/graphrag_sdk/steps/qa_step.py b/graphrag_sdk/steps/qa_step.py index bf25e5c..f0f18a7 100644 --- a/graphrag_sdk/steps/qa_step.py +++ b/graphrag_sdk/steps/qa_step.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Optional, Iterator from graphrag_sdk.steps.Step import Step from graphrag_sdk.models import GenerativeModelChatSession @@ -50,3 +50,25 @@ def run(self, question: str, cypher: str, context: str) -> str: qa_response = self.chat_session.send_message(qa_prompt) return qa_response.text + + def run_stream(self, question: str, cypher: str, context: str) -> Iterator[str]: + """ + Run the QA step and stream the response chunks. + + Args: + question (str): The question being asked. + cypher (str): The Cypher query to run. + context (str): Context for the QA. + + Returns: + Iterator[str]: A generator that yields response chunks. + """ + qa_prompt = self.qa_prompt.format( + context=context, cypher=cypher, question=question + ) + + logger.debug(f"QA Prompt (Stream): {qa_prompt}") + + # Send the message and stream the response + for chunk in self.chat_session.send_message_stream(qa_prompt): + yield chunk diff --git a/graphrag_sdk/steps/stream_qa_step.py b/graphrag_sdk/steps/stream_qa_step.py deleted file mode 100644 index ab2d2ba..0000000 --- a/graphrag_sdk/steps/stream_qa_step.py +++ /dev/null @@ -1,51 +0,0 @@ -import logging -from typing import Optional, Iterator -from graphrag_sdk.steps.Step import Step -from graphrag_sdk.models import GenerativeModelChatSession - - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -class StreamingQAStep(Step): - """ - QA Step that supports streaming responses - """ - - def __init__( - self, - chat_session: GenerativeModelChatSession, - config: Optional[dict] = None, - qa_prompt: Optional[str] = None, - ) -> None: - """ - Initialize the QA Step. - - Args: - chat_session (GenerativeModelChatSession): The chat session for handling the QA. - config (Optional[dict]): Optional configuration for the step. - qa_prompt (Optional[str]): The prompt template for question answering. - """ - self.config = config or {} - self.chat_session = chat_session - self.qa_prompt = qa_prompt - - def run(self, question: str, cypher: str, context: str) -> Iterator[str]: - """ - Run the QA step and stream the response chunks. - - Args: - question (str): The question being asked. - cypher (str): The Cypher query to run. 
-            context (str): Context for the QA.
-
-        Returns:
-            Iterator[str]: A generator that yields response chunks.
-        """
-        qa_prompt = self.qa_prompt.format(
-            context=context, cypher=cypher, question=question
-        )
-        logger.debug(f"QA Prompt: {qa_prompt}")
-        # Send the message and stream the response
-        for chunk in self.chat_session.send_message_stream(qa_prompt):
-            yield chunk
\ No newline at end of file
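
Usage sketch for reviewers: a minimal walk through the consolidated API after this change. It assumes a reachable FalkorDB instance and follows the `KnowledgeGraph("test_kg", model_config, ontology)` and `kg.chat_session()` calls shown in the updated class docstring; the `model` and `ontology` arguments are caller-supplied placeholders, not values defined by this patch.

from graphrag_sdk import KnowledgeGraph
from graphrag_sdk.model_config import KnowledgeGraphModelConfig

def demo(model, ontology):
    """Exercise the consolidated ChatSession API; `model` and `ontology` are caller-supplied."""
    model_config = KnowledgeGraphModelConfig.with_model(model)
    kg = KnowledgeGraph("test_kg", model_config, ontology)
    session = kg.chat_session()

    # Cypher-only path: generate_cypher_query is now public, so callers can
    # inspect the generated query and retrieved context without a QA round-trip.
    context, cypher = session.generate_cypher_query("Who are the actors in The Matrix?")
    print(cypher)

    # Streaming path: chunks now come from the single QAStep.run_stream method;
    # the complete answer is recovered from chat history once the stream ends.
    for chunk in session.send_message_stream("What is the capital of France?"):
        print(chunk, end="", flush=True)
    print()
    print(session.last_complete_response["response"])     # full answer, post-stream
    print(session.metadata["last_query_execution_time"])  # per-query timing

Reusing one QAStep for both blocking and streaming calls (instead of constructing a StreamingQAStep per message) also means the step's chat session, and therefore its history, persists across turns.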
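A hedged regression sketch for the reworked error handling: a failure inside GraphQueryGenerationStep.run should surface as an empty (context, cypher) pair, and send_message should answer with CYPHER_ERROR_RES while leaving the last successful response untouched. The `session` fixture and the monkeypatched attribute name are assumptions based on identifiers this diff introduces, not an existing test in the repository.

from unittest.mock import MagicMock

def test_failed_cypher_generation_keeps_last_response(session):
    # `session` is an assumed pytest fixture yielding a configured ChatSession.
    previous = session.last_complete_response

    # Force generate_cypher_query down its except branch.
    session.cypher_step.run = MagicMock(side_effect=RuntimeError("graph unavailable"))

    result = session.send_message("any question")

    # The error surfaces in the returned dict, not as an exception.
    assert result["context"] is None and result["cypher"] is None

    # Early return on error: the stored last response is not clobbered.
    assert session.last_complete_response is previous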