196 changes: 112 additions & 84 deletions graphrag_sdk/chat_session.py
@@ -3,7 +3,6 @@
from typing import Iterator
from graphrag_sdk.ontology import Ontology
from graphrag_sdk.steps.qa_step import QAStep
from graphrag_sdk.steps.stream_qa_step import StreamingQAStep
from graphrag_sdk.model_config import KnowledgeGraphModelConfig
from graphrag_sdk.steps.graph_query_step import GraphQueryGenerationStep

@@ -24,8 +23,19 @@ class ChatSession:
>>> from graphrag_sdk.model_config import KnowledgeGraphModelConfig
>>> model_config = KnowledgeGraphModelConfig.with_model(model)
>>> kg = KnowledgeGraph("test_kg", model_config, ontology)
>>> chat_session = kg.start_chat()
>>> chat_session.send_message("What is the capital of France?")
>>> session = kg.chat_session()
>>>
>>> # Full QA pipeline
>>> response = session.send_message("What is the capital of France?")
>>> print(response["question"]) # "What is the capital of France?"
>>> print(response["response"]) # "Paris"
>>> print(response["context"]) # Retrieved context
>>> print(response["cypher"]) # "MATCH (c:City)..."
>>>
>>> # Just generate Cypher query without QA
>>> context, cypher = session.generate_cypher_query("Who are the actors in The Matrix?")
>>> print(context) # Retrieved context
>>> print(cypher) # "MATCH (a:Actor)-[:ACTED_IN]->(m:Movie {title: 'The Matrix'}) RETURN a.name"
"""

def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, graph: Graph,
@@ -38,20 +48,25 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
model_config (KnowledgeGraphModelConfig): The model configuration.
ontology (Ontology): The ontology object.
graph (Graph): The graph object.
cypher_system_instruction (str): System instruction for cypher generation.
qa_system_instruction (str): System instruction for QA.
cypher_gen_prompt (str): Prompt template for cypher generation.
qa_prompt (str): Prompt template for QA.
cypher_gen_prompt_history (str): Prompt template for cypher generation with history.

Attributes:
model_config (KnowledgeGraphModelConfig): The model configuration.
ontology (Ontology): The ontology object.
graph (Graph): The graph object.
cypher_chat_session (CypherChatSession): The Cypher chat session object.
qa_chat_session (QAChatSession): The QA chat session object.
cypher_chat_session: The Cypher chat session object.
qa_chat_session: The QA chat session object.
"""
self.model_config = model_config
self.graph = graph
self.ontology = ontology

# Filter the ontology to remove unique and required attributes that are not needed for Q&A.
ontology_prompt = self.clean_ontology_for_prompt(ontology)
ontology_prompt = clean_ontology_for_prompt(ontology)

cypher_system_instruction = cypher_system_instruction.format(ontology=ontology_prompt)

@@ -65,6 +80,21 @@
self.qa_chat_session = model_config.qa.start_chat(
qa_system_instruction
)

# Initialize steps once during construction
self.qa_step = QAStep(
chat_session=self.qa_chat_session,
qa_prompt=self.qa_prompt,
)
self.cypher_step = GraphQueryGenerationStep(
graph=self.graph,
chat_session=self.cypher_chat_session,
ontology=self.ontology,
last_answer=None, # Will be updated dynamically
cypher_prompt=self.cypher_prompt,
cypher_prompt_with_history=self.cypher_prompt_with_history
)

self.last_complete_response = {
"question": None,
"response": None,
@@ -75,26 +105,37 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
# Metadata to store additional information about the chat session (currently only last query execution time)
self.metadata = {"last_query_execution_time": None}

def _generate_cypher_query(self, message: str) -> tuple:
def _update_last_complete_response(self, response_dict: dict):
"""Update the last complete response in both the session and cypher step."""
self.last_complete_response = response_dict
self.cypher_step.last_answer = response_dict.get("response")

def generate_cypher_query(self, message: str) -> tuple:
"""
Generate a Cypher query for the given message.
Generate a Cypher query for the given message without running QA.

This method allows users to get just the Cypher query and context
without executing the full question-answering pipeline.

Args:
message (str): The message to generate a query for.

Returns:
tuple: A tuple containing (context, cypher)
tuple: A tuple containing (context, cypher) where:
- context (str): The extracted context from the graph
- cypher (str): The generated Cypher query
"""
cypher_step = GraphQueryGenerationStep(
graph=self.graph,
chat_session=self.cypher_chat_session,
ontology=self.ontology,
last_answer=self.last_complete_response["response"],
cypher_prompt=self.cypher_prompt,
cypher_prompt_with_history=self.cypher_prompt_with_history
)

(context, cypher, query_execution_time) = cypher_step.run(message)
# Update the last_answer for this query
self.cypher_step.last_answer = self.last_complete_response.get("response")

try:
(context, cypher, query_execution_time) = self.cypher_step.run(message)
except Exception:
# On error, return no context and an error message in place of the Cypher query
context = None
cypher = CYPHER_ERROR_RES
query_execution_time = None

self.metadata["last_query_execution_time"] = query_execution_time

return (context, cypher)
@@ -107,42 +148,37 @@ def send_message(self, message: str) -> dict:
message (str): The message to send.

Returns:
dict: The response to the message in the following format:
{"question": message,
"response": answer,
"context": context,
"cypher": cypher}
dict: A dictionary containing the response with keys:
- "question": The original question
- "response": The answer
- "context": The extracted context from the graph
- "cypher": The generated Cypher query
"""
(context, cypher) = self._generate_cypher_query(message)
(context, cypher) = self.generate_cypher_query(message)

# If the cypher is empty, return an error message
# If the cypher is empty, return an error response
if not cypher or len(cypher) == 0:
self.last_complete_response = {
return {
"question": message,
"response": CYPHER_ERROR_RES,
"context": None,
"cypher": None
}
return self.last_complete_response

qa_step = QAStep(
chat_session=self.qa_chat_session,
qa_prompt=self.qa_prompt,
)

answer = qa_step.run(message, cypher, context)

self.last_complete_response = {
"question": message,
"response": answer,
"context": context,
answer = self.qa_step.run(message, cypher, context)

response = {
"question": message,
"response": answer,
"context": context,
"cypher": cypher
}

return self.last_complete_response
self._update_last_complete_response(response)

return response

def send_message_stream(self, message: str) -> Iterator[str]:
"""
Sends a message to the chat session and streams the response.

@@ -152,60 +188,52 @@ def send_message_stream(self, message: str) -> Iterator[str]:
Yields:
str: Chunks of the response as they're generated.
"""
(context, cypher) = self._generate_cypher_query(message)
(context, cypher) = self.generate_cypher_query(message)

if not cypher or len(cypher) == 0:
# Stream the error message for consistency with successful responses
yield CYPHER_ERROR_RES

self.last_complete_response = {
"question": message,
"response": CYPHER_ERROR_RES,
"context": None,
"cypher": None
}
return

qa_step = StreamingQAStep(
chat_session=self.qa_chat_session,
qa_prompt=self.qa_prompt,
)

# Yield chunks of the response as they're generated
for chunk in qa_step.run(message, cypher, context):
for chunk in self.qa_step.run_stream(message, cypher, context):
yield chunk

# Set the last answer using chat history to ensure we have the complete response
self.last_complete_response = {
"question": message,
"response": qa_step.chat_session.get_chat_history()[-1]['content'],
"context": context,
# Set the last answer using chat history to ensure complete response
final_answer = self.qa_step.chat_session.get_chat_history()[-1]['content']

final_response = {
"question": message,
"response": final_answer,
"context": context,
"cypher": cypher
}

def clean_ontology_for_prompt(self, ontology: dict) -> str:
"""
Cleans the ontology by removing 'unique' and 'required' keys and prepares it for use in a prompt.
self._update_last_complete_response(final_response)

Args:
ontology (dict): The ontology to clean and transform.
def clean_ontology_for_prompt(ontology: Ontology) -> str:
"""
Cleans the ontology by removing 'unique' and 'required' keys and prepares it for use in a prompt.

Returns:
str: The cleaned ontology as a JSON string.
"""
# Convert the ontology object to a JSON.
ontology = ontology.to_json()

# Remove unique and required attributes from the ontology.
for entity in ontology["entities"]:
for attribute in entity["attributes"]:
del attribute['unique']
del attribute['required']

for relation in ontology["relations"]:
for attribute in relation["attributes"]:
del attribute['unique']
del attribute['required']

# Return the transformed ontology as a JSON string
return json.dumps(ontology)
Args:
ontology (Ontology): The ontology to clean and transform.

Returns:
str: The cleaned ontology as a JSON string.
"""
# Convert the ontology object to a JSON.
ontology_json = ontology.to_json()

# Remove unique and required attributes from the ontology.
for entity in ontology_json.get("entities", []):
for attribute in entity["attributes"]:
attribute.pop('unique', None)
attribute.pop('required', None)

for relation in ontology_json.get("relations", []):
for attribute in relation["attributes"]:
attribute.pop('unique', None)
attribute.pop('required', None)

# Return the transformed ontology as a JSON string
return json.dumps(ontology_json)
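
The refactor above makes Cypher generation a public entry point and builds the pipeline steps once in the constructor instead of per call. A minimal usage sketch of the new surface follows; the `model` and `ontology` objects are assumed to come from your own setup (as in the class docstring), and the question strings are illustrative only:

    from graphrag_sdk import KnowledgeGraph
    from graphrag_sdk.model_config import KnowledgeGraphModelConfig

    # `model` and `ontology` are placeholders from your configuration,
    # mirroring the class docstring above.
    model_config = KnowledgeGraphModelConfig.with_model(model)
    kg = KnowledgeGraph("test_kg", model_config, ontology)
    session = kg.chat_session()

    # Inspect the generated Cypher and retrieved context without running QA.
    context, cypher = session.generate_cypher_query("Who directed The Matrix?")

    # Full pipeline: the session reuses the QAStep and
    # GraphQueryGenerationStep built in __init__.
    response = session.send_message("Who directed The Matrix?")
    print(response["response"])

    # Streaming: chunks are yielded as the model produces them.
    for chunk in session.send_message_stream("List three actors in The Matrix."):
        print(chunk, end="", flush=True)
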
24 changes: 23 additions & 1 deletion graphrag_sdk/steps/qa_step.py
@@ -1,5 +1,5 @@
import logging
from typing import Optional
from typing import Optional, Iterator
from graphrag_sdk.steps.Step import Step
from graphrag_sdk.models import GenerativeModelChatSession

@@ -50,3 +50,25 @@ def run(self, question: str, cypher: str, context: str) -> str:
qa_response = self.chat_session.send_message(qa_prompt)

return qa_response.text

def run_stream(self, question: str, cypher: str, context: str) -> Iterator[str]:
"""
Run the QA step and stream the response chunks.

Args:
question (str): The question being asked.
cypher (str): The Cypher query to run.
context (str): Context for the QA.

Yields:
str: Chunks of the response as they're generated.
"""
qa_prompt = self.qa_prompt.format(
context=context, cypher=cypher, question=question
)

logger.debug(f"QA Prompt (Stream): {qa_prompt}")

# Send the message and stream the response
for chunk in self.chat_session.send_message_stream(qa_prompt):
yield chunk
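
With streaming folded into QAStep, `run_stream` only requires that the chat session expose a `send_message_stream(prompt)` method yielding text chunks. A sketch with a hypothetical stub session (for illustration only, duck-typing the real GenerativeModelChatSession) makes that contract concrete:

    from typing import Iterator
    from graphrag_sdk.steps.qa_step import QAStep

    class StubChatSession:
        """Hypothetical stand-in for a GenerativeModelChatSession."""
        def send_message_stream(self, prompt: str) -> Iterator[str]:
            # Yield canned chunks instead of calling a real model.
            for chunk in ("Paris ", "is the ", "capital."):
                yield chunk

    step = QAStep(chat_session=StubChatSession(),
                  qa_prompt="{question}\n{cypher}\n{context}")
    for piece in step.run_stream("What is the capital of France?",
                                 "MATCH (c:City) RETURN c.name",
                                 "France -> Paris"):
        print(piece, end="")
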
51 changes: 0 additions & 51 deletions graphrag_sdk/steps/stream_qa_step.py

This file was deleted.