diff --git a/README.md b/README.md index 020bda9..316e273 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,6 @@ Build your AI agents in three lines of code! * [Python](https://python.org) - Programming Language * [OpenAI](https://openai.com) - AI Model Provider * [MongoDB](https://mongodb.com) - Conversational History (optional) -* [Zep Cloud](https://getzep.com) - Conversational Memory (optional) * [Pinecone](https://pinecone.io) - Knowledge Base (optional) * [AgentiPy](https://agentipy.fun) - Solana Ecosystem (optional) * [Zapier](https://zapier.com) - App Integrations (optional) @@ -82,7 +81,7 @@ Build your AI agents in three lines of code! **OpenAI** * [gpt-4.1](https://platform.openai.com/docs/models/gpt-4.1) (agent - can be overridden) * [gpt-4.1-nano](https://platform.openai.com/docs/models/gpt-4.1-nano) (router) -* [text-embedding-3-large](https://platform.openai.com/docs/models/text-embedding-3-large) (embedding) +* [text-embedding-3-small](https://platform.openai.com/docs/models/text-embedding-3-small) (embedding) * [tts-1](https://platform.openai.com/docs/models/tts-1) (audio TTS) * [gpt-4o-mini-transcribe](https://platform.openai.com/docs/models/gpt-4o-mini-transcribe) (audio transcription) * [gpt-image-1](https://platform.openai.com/docs/models/gpt-image-1) (image generation - can be overridden) @@ -377,16 +376,6 @@ config = { } ``` -### Conversational Memory - -```python -config = { - "zep": { - "api_key": "your-zep-cloud-api-key", - }, -} -``` - ### Observability and Tracing ```python @@ -421,9 +410,7 @@ config = { } ``` -### Knowledge Base - -The Knowledge Base (KB) is meant to store text values and/or PDFs (extracts text) - can handle very large PDFs. 
+### Knowledge Base & Conversational Memory ```python config = { diff --git a/docs/index.rst b/docs/index.rst index 712e9c0..c2798f7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -310,17 +310,6 @@ Conversational History - Optional }, } -Conversational Memory - Optional -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: python - - config = { - "zep": { - "api_key": "your-zep-api-key", - }, - } - Observability and Tracing - Optional ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -361,7 +350,7 @@ Gemini - Optional } -Knowledge Base - Optional +Knowledge Base & Conversational Memory - Optional ~~~~~~~~~~~~~~~~~~~~~~~~~~~ The Knowledge Base (KB) is meant to store text values and/or small PDFs. diff --git a/poetry.lock b/poetry.lock index 55a7714..7f53664 100644 --- a/poetry.lock +++ b/poetry.lock @@ -765,14 +765,14 @@ files = [ [[package]] name = "fsspec" -version = "2025.3.2" +version = "2025.5.0" description = "File-system specification" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "fsspec-2025.3.2-py3-none-any.whl", hash = "sha256:2daf8dc3d1dfa65b6aa37748d112773a7a08416f6c70d96b264c96476ecaf711"}, - {file = "fsspec-2025.3.2.tar.gz", hash = "sha256:e52c77ef398680bbd6a98c0e628fbc469491282981209907bbc8aea76a04fdc6"}, + {file = "fsspec-2025.5.0-py3-none-any.whl", hash = "sha256:0ca253eca6b5333d8a2b8bd98c7326fe821f1f0fdbd34e1b445bddde8e804c95"}, + {file = "fsspec-2025.5.0.tar.gz", hash = "sha256:e4f4623bb6221f7407fd695cc535d1f857a077eb247580f4ada34f5dc25fd5c8"}, ] [package.extras] @@ -1997,14 +1997,14 @@ xmp = ["defusedxml"] [[package]] name = "pinecone" -version = "6.0.2" +version = "7.0.0" description = "Pinecone client and SDK" optional = false python-versions = "<4.0,>=3.9" groups = ["main"] files = [ - {file = "pinecone-6.0.2-py3-none-any.whl", hash = "sha256:a85fa36d7d1451e7b7563ccfc7e3e2dadd39b33e5d53b2882468db8514ab8847"}, - {file = "pinecone-6.0.2.tar.gz", hash = 
"sha256:9c2e74be8b3abe76909da9b4dae61bced49aade51f6fc39b87edb97a1f8df0e4"}, + {file = "pinecone-7.0.0-py3-none-any.whl", hash = "sha256:59ce603c84545f82c58a306697d744a56778a4f6dd1dc195a015e6910ea2fc7b"}, + {file = "pinecone-7.0.0.tar.gz", hash = "sha256:8a4c4c12a7ee2e71b79781ec9df3abcfa5f352fd9217c777c3f2c7db2c0870f0"}, ] [package.dependencies] @@ -2015,7 +2015,7 @@ typing-extensions = ">=3.7.4" urllib3 = {version = ">=1.26.5", markers = "python_version >= \"3.12\" and python_version < \"4.0\""} [package.extras] -asyncio = ["aiohttp (>=3.9.0)"] +asyncio = ["aiohttp (>=3.9.0)", "aiohttp-retry (>=2.9.1,<3.0.0)"] grpc = ["googleapis-common-protos (>=1.66.0)", "grpcio (>=1.44.0) ; python_version >= \"3.8\" and python_version < \"3.11\"", "grpcio (>=1.59.0) ; python_version >= \"3.11\" and python_version < \"4.0\"", "grpcio (>=1.68.0) ; python_version >= \"3.13\" and python_version < \"4.0\"", "lz4 (>=3.1.3)", "protobuf (>=5.29,<6.0)", "protoc-gen-openapiv2 (>=0.0.1,<0.0.2)"] [[package]] @@ -3948,23 +3948,6 @@ idna = ">=2.0" multidict = ">=4.0" propcache = ">=0.2.1" -[[package]] -name = "zep-cloud" -version = "2.12.3" -description = "" -optional = false -python-versions = "<4.0,>=3.9.0" -groups = ["main"] -files = [ - {file = "zep_cloud-2.12.3-py3-none-any.whl", hash = "sha256:20b5e71b22f65e20fdd502c26bdf22878985db340c84b10a3f4e5b60b967eff9"}, - {file = "zep_cloud-2.12.3.tar.gz", hash = "sha256:0b98b66b1ed26a5911f3cfb53a56419c9dddbb7a417967277429eb348288b515"}, -] - -[package.dependencies] -httpx = ">=0.21.2" -pydantic = ">=1.9.2" -typing_extensions = ">=4.0.0" - [[package]] name = "zipp" version = "3.21.0" @@ -3988,4 +3971,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "9cd8ef48916bd232ae69a364e052854e8cddb8b4a649e4d497c2c389a96e9ce7" +content-hash = "637703379f106b9c4b5a90ee7dbf94960a3c860e78af0f348d30d41831bc4fbb" diff --git a/pyproject.toml b/pyproject.toml index f11c404..96b025a 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "solana-agent" -version = "29.1.4" +version = "30.0.0" description = "AI Agents for Solana" authors = ["Bevan Hunt "] license = "MIT" @@ -27,9 +27,8 @@ python = ">=3.12,<4.0" openai = "1.79.0" pydantic = ">=2" pymongo = "4.13.0" -zep-cloud = "2.12.3" instructor = "1.8.2" -pinecone = "6.0.2" +pinecone = "7.0.0" llama-index-core = "0.12.37" llama-index-embeddings-openai = "0.3.1" pypdf = "5.5.0" @@ -50,7 +49,7 @@ sphinx-rtd-theme = "^3.0.2" myst-parser = "^4.0.1" sphinx-autobuild = "^2024.10.3" mongomock = "^4.3.0" -ruff = "^0.11.9" +ruff = "^0.11.10" [tool.poetry.scripts] solana-agent = "solana_agent.cli:app" diff --git a/solana_agent/adapters/mongodb_graph_adapter.py b/solana_agent/adapters/mongodb_graph_adapter.py new file mode 100644 index 0000000..f49f495 --- /dev/null +++ b/solana_agent/adapters/mongodb_graph_adapter.py @@ -0,0 +1,67 @@ +import uuid +from typing import Dict, Any, List, Optional +from solana_agent.interfaces.providers.graph_storage import GraphStorageProvider +from solana_agent.adapters.mongodb_adapter import MongoDBAdapter + + +class MongoDBGraphAdapter(GraphStorageProvider): + def __init__( + self, + mongo_adapter: MongoDBAdapter, + node_collection: str = "graph_nodes", + edge_collection: str = "graph_edges", + ): + self.mongo = mongo_adapter + self.node_collection = node_collection + self.edge_collection = edge_collection + + async def add_node(self, node: Dict[str, Any]) -> str: + node = dict(node) + node["uuid"] = node.get("uuid", str(uuid.uuid4())) + self.mongo.insert_one(self.node_collection, node) + return node["uuid"] + + async def add_edge(self, edge: Dict[str, Any]) -> str: + edge = dict(edge) + edge["uuid"] = edge.get("uuid", str(uuid.uuid4())) + return self.mongo.insert_one(self.edge_collection, edge) + + async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]: + return self.mongo.find_one(self.node_collection, {"uuid": node_id}) + + async def 
get_edges(
+        self, node_id: str, direction: str = "both"
+    ) -> List[Dict[str, Any]]:
+        if direction == "out":
+            query = {"source": node_id}
+        elif direction == "in":
+            query = {"target": node_id}
+        else:
+            query = {"$or": [{"source": node_id}, {"target": node_id}]}
+        return self.mongo.find(self.edge_collection, query)
+
+    async def find_neighbors(
+        self, node_id: str, depth: int = 1
+    ) -> List[Dict[str, Any]]:
+        neighbors = set()
+        current = {node_id}
+        for _ in range(depth):
+            edges = [e for nid in current for e in await self.get_edges(nid)]
+            for edge in edges:
+                neighbors.add(edge.get("source"))
+                neighbors.add(edge.get("target"))
+            current = neighbors
+        neighbors.discard(node_id)
+        return [await self.get_node(nid) for nid in neighbors if nid]
+
+    async def temporal_query(
+        self, node_id: str, start_time: Optional[str], end_time: Optional[str]
+    ) -> List[Dict[str, Any]]:
+        query = {"$or": [{"source": node_id}, {"target": node_id}]}
+        if start_time or end_time:
+            query["timestamp"] = {}
+        if start_time:
+            query["timestamp"]["$gte"] = start_time
+        if end_time:
+            query["timestamp"]["$lte"] = end_time
+        return self.mongo.find(self.edge_collection, query)
diff --git a/solana_agent/adapters/openai_adapter.py b/solana_agent/adapters/openai_adapter.py
index 50e7202..3ca4cc5 100644
--- a/solana_agent/adapters/openai_adapter.py
+++ b/solana_agent/adapters/openai_adapter.py
@@ -36,8 +36,8 @@ DEFAULT_CHAT_MODEL = "gpt-4.1"
 DEFAULT_VISION_MODEL = "gpt-4.1"
 DEFAULT_PARSE_MODEL = "gpt-4.1-nano"
-DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
-DEFAULT_EMBEDDING_DIMENSIONS = 3072
+DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
+DEFAULT_EMBEDDING_DIMENSIONS = 1536
 DEFAULT_TRANSCRIPTION_MODEL = "gpt-4o-mini-transcribe"
 DEFAULT_TTS_MODEL = "tts-1"
@@ -526,7 +526,7 @@ async def embed_text(
         Args:
             text: The text to embed.
-            model: The embedding model to use (defaults to text-embedding-3-large).
+            model: The embedding model to use (defaults to text-embedding-3-small).
dimensions: Desired output dimensions for the embedding. Returns: diff --git a/solana_agent/adapters/pinecone_adapter.py b/solana_agent/adapters/pinecone_adapter.py index 221d87c..47f5429 100644 --- a/solana_agent/adapters/pinecone_adapter.py +++ b/solana_agent/adapters/pinecone_adapter.py @@ -33,7 +33,7 @@ def __init__( self, api_key: Optional[str] = None, index_name: Optional[str] = None, - # Default for OpenAI text-embedding-3-large, MUST match external embedder + # Default for OpenAI text-embedding-3-small, MUST match external embedder embedding_dimensions: int = 3072, cloud_provider: str = "aws", region: str = "us-east-1", @@ -41,7 +41,7 @@ def __init__( create_index_if_not_exists: bool = True, # Reranking Config use_reranking: bool = False, - rerank_model: Optional[PineconeRerankModel] = None, + rerank_model: Optional[PineconeRerankModel] = "cohere-rerank-3.5", rerank_top_k: int = 3, # Final number of results after reranking # Multiplier for initial fetch before rerank initial_query_top_k_multiplier: int = 5, @@ -371,7 +371,7 @@ async def query_and_rerank( "parameters": rerank_params, } - rerank_response = await self.pinecone.rerank(**rerank_request) + rerank_response = await self.pinecone.inference.rerank(**rerank_request) # 4. 
Process Reranked Results reranked_results = [] diff --git a/solana_agent/factories/agent_factory.py b/solana_agent/factories/agent_factory.py index 33bc282..04bd976 100644 --- a/solana_agent/factories/agent_factory.py +++ b/solana_agent/factories/agent_factory.py @@ -10,11 +10,13 @@ from typing import Dict, Any, List # Service imports +from solana_agent.adapters.mongodb_graph_adapter import MongoDBGraphAdapter from solana_agent.adapters.pinecone_adapter import PineconeAdapter from solana_agent.interfaces.guardrails.guardrails import ( InputGuardrail, OutputGuardrail, ) +from solana_agent.services.graph_memory import GraphMemoryService from solana_agent.services.query import QueryService from solana_agent.services.agent import AgentService from solana_agent.services.routing import RoutingService @@ -134,21 +136,9 @@ def create_from_config(config: Dict[str, Any]) -> QueryService: ) # Create repositories - memory_provider = None - - if "zep" in config and "mongo" in config: - memory_provider = MemoryRepository( - mongo_adapter=db_adapter, zep_api_key=config["zep"].get("api_key") - ) - - if "mongo" in config and "zep" not in config: - memory_provider = MemoryRepository(mongo_adapter=db_adapter) - - if "zep" in config and "mongo" not in config: - if "api_key" not in config["zep"]: - raise ValueError("Zep API key is required.") - memory_provider = MemoryRepository(zep_api_key=config["zep"].get("api_key")) + memory_provider = MemoryRepository(mongo_adapter=db_adapter) + # Create guardrails guardrail_config = config.get("guardrails", {}) input_guardrails: List[InputGuardrail] = SolanaAgentFactory._create_guardrails( guardrail_config.get("input", []) @@ -241,6 +231,7 @@ def create_from_config(config: Dict[str, Any]) -> QueryService: # Initialize Knowledge Base if configured knowledge_base = None + graph_memory_service = None kb_config = config.get("knowledge_base") # Requires both KB config section and MongoDB adapter if kb_config and db_adapter: @@ -252,10 +243,10 @@ def 
create_from_config(config: Dict[str, Any]) -> QueryService: # Determine OpenAI model and dimensions for KBService openai_model_name = openai_embed_config.get( - "model_name", "text-embedding-3-large" + "model_name", "text-embedding-3-small" ) - if openai_model_name == "text-embedding-3-large": - openai_dimensions = 3072 + if openai_model_name == "text-embedding-3-large": # pragma: no cover + openai_dimensions = 3072 # pragma: no cover elif openai_model_name == "text-embedding-3-small": # pragma: no cover openai_dimensions = 1536 # pragma: no cover else: # pragma: no cover @@ -279,7 +270,9 @@ def create_from_config(config: Dict[str, Any]) -> QueryService: ), # Reranking config use_reranking=pinecone_config.get("use_reranking", False), - rerank_model=pinecone_config.get("rerank_model"), + rerank_model=pinecone_config.get( + "rerank_model", "cohere-rerank-3.5" + ), rerank_top_k=pinecone_config.get("rerank_top_k", 3), initial_query_top_k_multiplier=pinecone_config.get( "initial_query_top_k_multiplier", 5 @@ -311,6 +304,34 @@ def create_from_config(config: Dict[str, Any]) -> QueryService: "Knowledge Base Service initialized successfully." 
) # Use logger.info + # Create Graph Memory Service + graph_memory_config = pinecone_config.get("agent_memory", {}) + try: + # Create MongoDBGraphAdapter + mongo_graph_adapter = MongoDBGraphAdapter( + mongo_adapter=db_adapter, + node_collection=graph_memory_config.get( + "node_collection", "graph_nodes" + ), + edge_collection=graph_memory_config.get( + "edge_collection", "graph_edges" + ), + ) + + # Create GraphMemoryService + graph_memory_service = GraphMemoryService( + graph_adapter=mongo_graph_adapter, + pinecone_adapter=pinecone_adapter, + openai_adapter=llm_adapter, + embedding_model=graph_memory_config.get( + "embedding_model", "text-embedding-3-small" + ), + ) + logger.info("Graph Memory Service initialized successfully.") + except Exception as e: + logger.exception(f"Failed to initialize Graph Memory: {e}") + graph_memory_service = None + except Exception as e: # Use logger.exception to include traceback automatically logger.exception(f"Failed to initialize Knowledge Base: {e}") @@ -324,6 +345,7 @@ def create_from_config(config: Dict[str, Any]) -> QueryService: knowledge_base=knowledge_base, # Pass the potentially created KB kb_results_count=kb_config.get("results_count", 3) if kb_config else 3, input_guardrails=input_guardrails, + graph_memory=graph_memory_service, ) return query_service diff --git a/solana_agent/interfaces/providers/graph_storage.py b/solana_agent/interfaces/providers/graph_storage.py new file mode 100644 index 0000000..fc68578 --- /dev/null +++ b/solana_agent/interfaces/providers/graph_storage.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional + + +class GraphStorageProvider(ABC): + @abstractmethod + async def add_node(self, node: Dict[str, Any]) -> str: + pass + + @abstractmethod + async def add_edge(self, edge: Dict[str, Any]) -> str: + pass + + @abstractmethod + async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]: + pass + + @abstractmethod + async def get_edges( + 
self, node_id: str, direction: str = "both" + ) -> List[Dict[str, Any]]: + pass + + @abstractmethod + async def find_neighbors( + self, node_id: str, depth: int = 1 + ) -> List[Dict[str, Any]]: + pass + + @abstractmethod + async def temporal_query( + self, node_id: str, start_time: Optional[str], end_time: Optional[str] + ) -> List[Dict[str, Any]]: + pass diff --git a/solana_agent/interfaces/services/graph_memory.py b/solana_agent/interfaces/services/graph_memory.py new file mode 100644 index 0000000..ef25129 --- /dev/null +++ b/solana_agent/interfaces/services/graph_memory.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Any + + +class GraphMemoryService(ABC): + """ + Interface for a graph memory service. + """ + + @abstractmethod + async def add_episode( + self, + user_message: str, + assistant_message: str, + user_id: str, + ): + """ + Add an episode to the graph memory. + """ + pass + + @abstractmethod + async def search( + self, query: str, user_id: str, top_k: int = 5 + ) -> List[Dict[str, Any]]: + """ + Search the graph memory for relevant episodes. + """ + pass + + @abstractmethod + async def traverse(self, node_id: str, depth: int = 1) -> List[Dict[str, Any]]: + """ + Traverse the graph memory from a given node ID. 
+ """ + pass diff --git a/solana_agent/repositories/memory.py b/solana_agent/repositories/memory.py index 7cc6a82..20ba681 100644 --- a/solana_agent/repositories/memory.py +++ b/solana_agent/repositories/memory.py @@ -1,9 +1,6 @@ import logging # Import logging -from copy import deepcopy from typing import List, Dict, Any, Optional, Tuple from datetime import datetime, timezone -from zep_cloud.client import AsyncZep as AsyncZepCloud -from zep_cloud.types import Message from solana_agent.interfaces.providers.memory import MemoryProvider from solana_agent.adapters.mongodb_adapter import MongoDBAdapter @@ -12,12 +9,11 @@ class MemoryRepository(MemoryProvider): - """Combined Zep and MongoDB implementation of MemoryProvider.""" + """MongoDB implementation of MemoryProvider.""" def __init__( self, mongo_adapter: Optional[MongoDBAdapter] = None, - zep_api_key: Optional[str] = None, ): """Initialize the combined memory provider.""" if not mongo_adapter: @@ -36,13 +32,8 @@ def __init__( except Exception as e: logger.error(f"Error initializing MongoDB: {e}") # Use logger.error - self.zep = None - # Initialize Zep - if zep_api_key: - self.zep = AsyncZepCloud(api_key=zep_api_key) - async def store(self, user_id: str, messages: List[Dict[str, Any]]) -> None: - """Store messages in both Zep and MongoDB.""" + """Store messages in MongoDB.""" if not user_id: raise ValueError("User ID cannot be None or empty") if not messages or not isinstance(messages, list): @@ -86,71 +77,22 @@ async def store(self, user_id: str, messages: List[Dict[str, Any]]) -> None: except Exception as e: logger.error(f"MongoDB storage error: {e}") # Use logger.error - # Store in Zep - if not self.zep: - return - - # Convert messages to Zep format - zep_messages = [] - for msg in messages: - if "role" in msg and "content" in msg: - content = self._truncate(deepcopy(msg["content"])) - zep_msg = Message( - role=msg["role"], - content=content, - role_type=msg["role"], - ) - zep_messages.append(zep_msg) - - # Add 
messages to Zep memory - if zep_messages: - try: - await self.zep.memory.add(session_id=user_id, messages=zep_messages) - except Exception: - try: - try: - await self.zep.user.add(user_id=user_id) - except Exception as e: - logger.error( - f"Zep user addition error: {e}" - ) # Use logger.error - - try: - await self.zep.memory.add_session( - session_id=user_id, user_id=user_id - ) - except Exception as e: - logger.error( - f"Zep session creation error: {e}" - ) # Use logger.error - await self.zep.memory.add(session_id=user_id, messages=zep_messages) - except Exception as e: - logger.error(f"Zep memory addition error: {e}") # Use logger.error - return - async def retrieve(self, user_id: str) -> str: - """Retrieve memory context from Zep and MongoDB.""" + """Retrieve memory context from MongoDB.""" try: memories = "" - if self.zep: - memory = await self.zep.memory.get(session_id=user_id) - if memory and memory.context: - memories = memory.context - if self.mongo: query = {"user_id": user_id} sort = [("timestamp", -1)] - limit = 1 + limit = 3 skip = 0 - mongo_memory = self.mongo.find( + results = self.mongo.find( self.collection, query, sort=sort, limit=limit, skip=skip ) - if mongo_memory: - for doc in mongo_memory: - user_message = doc.get("user_message") - assistant_message = doc.get("assistant_message") - if user_message and assistant_message: - memories += f"\nUser: {user_message}\nAssistant: {assistant_message}" + if results: + for result in results: + memories += f"User: {result.get('user_message')}\n" + memories += f"Assistant: {result.get('assistant_message')}\n" return memories except Exception as e: @@ -165,19 +107,6 @@ async def delete(self, user_id: str) -> None: except Exception as e: logger.error(f"MongoDB deletion error: {e}") # Use logger.error - if not self.zep: - return - - try: - await self.zep.memory.delete(session_id=user_id) - except Exception as e: - logger.error(f"Zep memory deletion error: {e}") # Use logger.error - - try: - await 
self.zep.user.delete(user_id=user_id) - except Exception as e: - logger.error(f"Zep user deletion error: {e}") # Use logger.error - def find( self, collection: str, @@ -201,22 +130,3 @@ def count_documents(self, collection: str, query: Dict) -> int: if not self.mongo: return 0 return self.mongo.count_documents(collection, query) - - def _truncate(self, text: str, limit: int = 2500) -> str: - """Truncate text to be within limits.""" - if text is None: - raise AttributeError("Cannot truncate None text") - - if not text: - return "" - - if len(text) <= limit: - return text - - # Try to truncate at last period before limit - last_period = text.rfind(".", 0, limit) - if last_period > 0: - return text[: last_period + 1] - - # If no period found, truncate at limit and add ellipsis - return text[: limit - 3] + "..." diff --git a/solana_agent/services/graph_memory.py b/solana_agent/services/graph_memory.py new file mode 100644 index 0000000..dc569e8 --- /dev/null +++ b/solana_agent/services/graph_memory.py @@ -0,0 +1,102 @@ +import datetime +import uuid +from typing import Dict, Any, List +from solana_agent.adapters.openai_adapter import OpenAIAdapter +from solana_agent.adapters.pinecone_adapter import PineconeAdapter +from solana_agent.adapters.mongodb_graph_adapter import MongoDBGraphAdapter +from solana_agent.interfaces.services.graph_memory import ( + GraphMemoryService as GraphMemoryServiceInterface, +) + + +class GraphMemoryService(GraphMemoryServiceInterface): + def __init__( + self, + graph_adapter: MongoDBGraphAdapter, + pinecone_adapter: PineconeAdapter, + openai_adapter: OpenAIAdapter, + embedding_model: str = "text-embedding-3-small", + ): + self.graph = graph_adapter + self.pinecone = pinecone_adapter + self.openai = openai_adapter + self.embedding_model = embedding_model + + async def add_episode( + self, + user_message: str, + assistant_message: str, + user_id: str, + ): + entities = [ + {"type": "user_message", "text": user_message, "user_id": user_id}, + { 
+                "type": "assistant_message",
+                "text": assistant_message,
+                "user_id": user_id,
+            },
+        ]
+        episode_id = str(uuid.uuid4())
+        node_ids = []
+        for entity in entities:
+            entity["episode_id"] = episode_id
+            node_id = await self.graph.add_node(entity)
+            node_ids.append(node_id)
+        edge = {
+            "source": node_ids[0],
+            "target": node_ids[1],
+            "type": "reply",
+            "episode_id": episode_id,
+            "user_id": user_id,
+            "timestamp": datetime.datetime.now(datetime.timezone.utc),
+        }
+        await self.graph.add_edge(edge)
+        # Save vectors in user-specific namespace
+        namespace = f"{user_id}_memory"
+        for node_id, entity in zip(node_ids, entities):
+            embedding = await self.openai.embed_text(
+                entity["text"], model=self.embedding_model
+            )
+            await self.pinecone.upsert(
+                [
+                    {
+                        "id": node_id,
+                        "values": embedding,
+                        "metadata": {"text": self._truncate_text(entity["text"])},
+                    }
+                ],
+                namespace=namespace,
+            )
+        return episode_id
+
+    async def search(
+        self, query: str, user_id: str, top_k: int = 5
+    ) -> List[Dict[str, Any]]:
+        embedding = await self.openai.embed_text(query, model=self.embedding_model)
+        namespace = f"{user_id}_memory"
+        results = await self.pinecone.query_and_rerank(
+            vector=embedding,
+            query_text_for_rerank=query,
+            top_k=top_k,
+            namespace=namespace,
+        )
+        node_ids = [r["id"] for r in results]
+        # Only return nodes that match user_id
+        nodes = []
+        for nid in node_ids:
+            node = await self.graph.get_node(nid)
+            if node and node.get("user_id") == user_id:
+                nodes.append(node)
+        return nodes
+
+    async def traverse(self, node_id: str, depth: int = 1) -> List[Dict[str, Any]]:
+        return await self.graph.find_neighbors(node_id, depth=depth)
+
+    # pinecone has a 40kb character limit for text
+    def _truncate_text(self, text: str, max_length: int = 40960) -> str:
+        """
+        Truncate text to a maximum length.
+        """
+        if len(text) > max_length - 3:
+            return text[: max_length - 3] + "..."
+        return text
diff --git a/solana_agent/services/knowledge_base.py b/solana_agent/services/knowledge_base.py
index f01ab51..ce567de 100644
--- a/solana_agent/services/knowledge_base.py
+++ b/solana_agent/services/knowledge_base.py
@@ -33,7 +33,7 @@ def __init__(
         self,
         pinecone_adapter: PineconeAdapter,
         mongodb_adapter: MongoDBAdapter,
         openai_api_key: str,
-        openai_model_name: str = "text-embedding-3-large",
+        openai_model_name: str = "text-embedding-3-small",
         collection_name: str = "knowledge_documents",
         rerank_results: bool = False,
         rerank_top_k: int = 3,
@@ -70,7 +70,7 @@ def __init__(
         )
         # Determine expected embedding dimensions based on model name
-        if openai_model_name == "text-embedding-3-large":
+        if openai_model_name == "text-embedding-3-large":
             openai_dimensions = 3072
         elif openai_model_name == "text-embedding-3-small":
             openai_dimensions = 1536
diff --git a/solana_agent/services/query.py b/solana_agent/services/query.py
index b48a387..431ffda 100644
--- a/solana_agent/services/query.py
+++ b/solana_agent/services/query.py
@@ -17,12 +17,13 @@
 from solana_agent.interfaces.providers.memory import (
     MemoryProvider as MemoryProviderInterface,
 )
-from solana_agent.interfaces.services.knowledge_base import (
-    KnowledgeBaseService as KnowledgeBaseInterface,
+from solana_agent.services.knowledge_base import (
+    KnowledgeBaseService,
 )
 from solana_agent.interfaces.guardrails.guardrails import (
     InputGuardrail,
 )
+from solana_agent.services.graph_memory import GraphMemoryService
 from solana_agent.services.agent import AgentService
 from solana_agent.services.routing import RoutingService
@@ -38,9 +39,10 @@ def __init__(
         self,
         agent_service: AgentService,
         routing_service: RoutingService,
         memory_provider: Optional[MemoryProviderInterface] = None,
-        knowledge_base: Optional[KnowledgeBaseInterface] = None,
+        knowledge_base: Optional[KnowledgeBaseService] = None,
         kb_results_count: int = 3,
         input_guardrails: List[InputGuardrail] = None,
+        graph_memory: Optional[GraphMemoryService] =
None, ): """Initialize the query service. @@ -51,6 +53,7 @@ def __init__( knowledge_base: Optional provider for knowledge base interactions kb_results_count: Number of results to retrieve from knowledge base input_guardrails: List of input guardrail instances + graph_memory: Optional graph memory service instance """ self.agent_service = agent_service self.routing_service = routing_service @@ -58,6 +61,7 @@ def __init__( self.knowledge_base = knowledge_base self.kb_results_count = kb_results_count self.input_guardrails = input_guardrails or [] + self.graph_memory = graph_memory async def process( self, @@ -168,6 +172,18 @@ async def process( # Store simple interaction in memory (using processed user_text) if self.memory_provider: await self._store_conversation(user_id, user_text, response) + # --- Store in graph memory if available --- + if self.graph_memory: + try: + await self.graph_memory.add_episode( + user_message=user_text, + assistant_message=response, + user_id=user_id, + ) + except Exception as e: + logger.error( + f"Error storing in graph memory: {e}", exc_info=True + ) return # --- 4. Get Memory Context --- @@ -232,6 +248,43 @@ async def process( combined_context += "CRITICAL PRIORITIZATION GUIDE: For factual or current information, prioritize Knowledge Base results and Tool results (if applicable) over Conversation History.\n\n" logger.debug(f"Combined context length: {len(combined_context)}") + # --- 7a. 
Retrieve Graph Memory Context (NEW) --- + graph_context = "" + if self.graph_memory: + try: + # Use processed user_text for graph memory search + graph_results = await self.graph_memory.search( + user_id=user_id, + query=user_text, + top_k=3, + ) + if graph_results: + graph_context = ( + "**GRAPH MEMORY (Relevant Episodes/Entities):**\n" + ) + for i, node in enumerate(graph_results, 1): + node_type = node.get("type", "node") + node_text = node.get("text", "") + episode_id = node.get("episode_id", "") + graph_context += f"[{i}] ({node_type}) {node_text}" + if episode_id: + graph_context += f" (Episode: {episode_id})" + graph_context += "\n\n" + logger.info( + f"Retrieved {len(graph_results)} results from Graph Memory." + ) + else: + logger.info("No relevant results found in Graph Memory.") + except Exception as e: + logger.error(f"Error retrieving graph memory: {e}", exc_info=True) + + if graph_context: + combined_context += f"{graph_context}\n" + + if memory_context or kb_context or graph_context: + combined_context += "CRITICAL PRIORITIZATION GUIDE: For factual or current information, prioritize Knowledge Base, Graph Memory, and Tool results (if applicable) over Conversation History.\n\n" + logger.debug(f"Combined context length: {len(combined_context)}") + # --- 8. 
Generate Response --- # Pass the processed user_text and images to the agent service if output_format == "audio": @@ -257,6 +310,20 @@ async def process( user_message=user_text, # Store only text part of user query assistant_message=self.agent_service.last_text_response, ) + # --- Store in graph memory if available --- + if self.graph_memory and hasattr( + self.agent_service, "last_text_response" + ): + try: + await self.graph_memory.add_episode( + user_message=user_text, + assistant_message=self.agent_service.last_text_response, + user_id=user_id, + ) + except Exception as e: + logger.error( + f"Error storing in graph memory: {e}", exc_info=True + ) else: full_text_response = "" async for chunk in self.agent_service.generate_response( @@ -279,6 +346,18 @@ async def process( user_message=user_text, # Store only text part of user query assistant_message=full_text_response, ) + # --- Store in graph memory if available --- + if self.graph_memory and full_text_response: + try: + await self.graph_memory.add_episode( + user_message=user_text, + assistant_message=full_text_response, + user_id=user_id, + ) + except Exception as e: + logger.error( + f"Error storing in graph memory: {e}", exc_info=True + ) except Exception as e: import traceback diff --git a/tests/unit/factories/test_agent_factory.py b/tests/unit/factories/test_agent_factory.py index ff0f509..099c24d 100644 --- a/tests/unit/factories/test_agent_factory.py +++ b/tests/unit/factories/test_agent_factory.py @@ -703,7 +703,7 @@ def test_create_from_config_with_knowledge_base( pinecone_adapter=mock_pinecone_instance, mongodb_adapter=mock_mongo_instance, openai_api_key="test-openai-key", # From config - openai_model_name="text-embedding-3-large", + openai_model_name="text-embedding-3-small", collection_name="knowledge_documents", # From config rerank_results=True, # From config (pinecone.use_reranking) rerank_top_k=3, # From config (knowledge_base.results_count) diff --git 
a/tests/unit/repositories/test_memory_repo.py b/tests/unit/repositories/test_memory_repo.py index 9a060d7..f574fe2 100644 --- a/tests/unit/repositories/test_memory_repo.py +++ b/tests/unit/repositories/test_memory_repo.py @@ -6,12 +6,11 @@ """ import pytest -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import MagicMock, patch from datetime import datetime from solana_agent.repositories.memory import MemoryRepository from solana_agent.adapters.mongodb_adapter import MongoDBAdapter -from zep_cloud.types import Memory @pytest.fixture @@ -27,18 +26,6 @@ def mock_mongo_adapter(): return adapter -@pytest.fixture -def mock_zep(): - """Create a mock Zep client.""" - mock = AsyncMock() - mock.user = AsyncMock() - mock.memory = AsyncMock() - memory = MagicMock(spec=Memory) - memory.context = "Test memory context" - mock.memory.get.return_value = memory - return mock - - @pytest.fixture def valid_messages(): """Valid message list for testing.""" @@ -68,7 +55,6 @@ def test_init_default(self): repo = MemoryRepository() assert repo.mongo is None assert repo.collection is None - assert repo.zep is None def test_init_mongo_only(self, mock_mongo_adapter): """Test initialization with MongoDB only.""" @@ -84,12 +70,6 @@ def test_init_mongo_error(self, mock_mongo_adapter): repo = MemoryRepository(mongo_adapter=mock_mongo_adapter) assert repo.mongo == mock_mongo_adapter - @patch("solana_agent.repositories.memory.AsyncZepCloud") - def test_init_zep_cloud(self, mock_zep_cloud): - """Test initialization with Zep Cloud.""" - MemoryRepository(zep_api_key="test_key") - mock_zep_cloud.assert_called_once_with(api_key="test_key") - @pytest.mark.asyncio async def test_store_validation_errors(self, mock_mongo_adapter, invalid_messages): """Test message validation errors.""" @@ -133,96 +113,6 @@ async def test_store_mongo_error(self, mock_mongo_adapter, valid_messages): await repo.store("user123", valid_messages) mock_mongo_adapter.insert_one.assert_called_once() - 
@pytest.mark.asyncio - async def test_store_zep_direct_success(self, mock_zep, valid_messages): - """Test successful direct Zep storage without fallback.""" - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - - # Memory.add will succeed on first try - await repo.store("user123", valid_messages) - - # Verify direct path - mock_zep.memory.add.assert_called_once() - mock_zep.user.add.assert_not_called() - mock_zep.memory.add_session.assert_not_called() - - @pytest.mark.asyncio - async def test_store_zep_success(self, mock_zep, valid_messages): - """Test successful Zep storage.""" - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - - # Make the first memory.add call fail to trigger the fallback path - mock_zep.memory.add.side_effect = [Exception("First call fails"), None] - - await repo.store("user123", valid_messages) - - # Now verify the fallback path was called - mock_zep.user.add.assert_called_once_with(user_id="user123") - mock_zep.memory.add_session.assert_called_once_with( - session_id="user123", user_id="user123" - ) - # Verify memory.add was called twice (once failing, once succeeding) - assert mock_zep.memory.add.call_count == 2 - - @pytest.mark.asyncio - async def test_store_zep_errors(self, mock_zep, valid_messages): - """Test handling Zep storage errors.""" - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - - # Make first memory.add fail, then user.add fail - mock_zep.memory.add.side_effect = Exception("Memory error") - mock_zep.user.add.side_effect = Exception("User error") - - await repo.store("user123", valid_messages) - - # Verify add_session was still called despite user.add failing - mock_zep.memory.add_session.assert_called_once_with( - session_id="user123", user_id="user123" - ) - - @pytest.mark.asyncio - async def test_store_zep_session_creation_error(self, mock_zep, valid_messages): - """Test handling the specific case where session creation fails.""" - repo = 
MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - - # Configure mocks to hit the specific error path: - # 1. First memory.add fails - mock_zep.memory.add.side_effect = [ - Exception("Session not found"), # First call fails - None, # Second call succeeds (we'll reach this if code continues) - ] - - # 2. User creation succeeds - mock_zep.user.add.return_value = None - - # 3. Session creation raises exception (this is what we want to test) - mock_zep.memory.add_session.side_effect = Exception("Session creation failed") - - # Call the method - await repo.store("user123", valid_messages) - - # Verify: - # - Initial memory.add was called and failed - # - User.add was called and succeeded - # - Memory.add_session was called and failed (our target scenario) - # - Code continued and tried to add messages again - mock_zep.memory.add.assert_called() - mock_zep.user.add.assert_called_once_with(user_id="user123") - mock_zep.memory.add_session.assert_called_once_with( - session_id="user123", user_id="user123" - ) - - # Verify we reached the print statement by checking call count - # (2 calls means we tried the second add after the session error) - assert mock_zep.memory.add.call_count == 2 - - # Optional: add a mock for print and verify it was called with the error message - # This requires patch("builtins.print") in the test setup - @pytest.mark.asyncio async def test_retrieve_success_no_zep(self): """Test successful memory retrieval.""" @@ -231,63 +121,52 @@ async def test_retrieve_success_no_zep(self): assert result == "" @pytest.mark.asyncio - async def test_retrieve_memory_context_success(self, mock_zep): - """Test successful retrieval of memory context from Zep.""" - # Setup - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - - # Create a mock for the memory object with a context attribute - mock_memory = MagicMock() - mock_memory.context = "Sample memory context data" - - # Configure the mock to return our memory object - 
mock_zep.memory.get.return_value = mock_memory - - # Call retrieve method + async def test_retrieve_memory_context_success(self, mock_mongo_adapter): + """Test successful retrieval of memory context from MongoDB.""" + repo = MemoryRepository(mongo_adapter=mock_mongo_adapter) + # Mock MongoDB find to return recent user/assistant messages + mock_mongo_adapter.find.return_value = [ + {"user_message": "Hello", "assistant_message": "Hi there"}, + {"user_message": "How are you?", "assistant_message": "I'm good!"}, + ] result = await repo.retrieve("test_user") - - # Verify correct behavior - mock_zep.memory.get.assert_called_once_with(session_id="test_user") - assert result == "Sample memory context data" + # Should contain both pairs in order + assert "User: Hello" in result + assert "Assistant: Hi there" in result + assert "User: How are you?" in result + assert "Assistant: I'm good!" in result + mock_mongo_adapter.find.assert_called_once_with( + "conversations", + {"user_id": "test_user"}, + sort=[("timestamp", -1)], + limit=3, + skip=0, + ) @pytest.mark.asyncio - async def test_retrieve_errors(self, mock_zep): - """Test memory retrieval errors.""" - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - - # Test None memory - mock_zep.memory.get.return_value = None - result = await repo.retrieve("user123") - assert result == "" + async def test_retrieve_errors(self, mock_mongo_adapter): + """Test memory retrieval errors and empty results.""" + repo = MemoryRepository(mongo_adapter=mock_mongo_adapter) - # Test missing context - memory = MagicMock(spec=Memory) - memory.context = None - mock_zep.memory.get.return_value = memory + # Test no results + mock_mongo_adapter.find.return_value = [] result = await repo.retrieve("user123") assert result == "" - # Test retrieval error - mock_zep.memory.get.side_effect = Exception("Retrieval error") + # Test MongoDB error + mock_mongo_adapter.find.side_effect = Exception("Find error") result = await 
repo.retrieve("user123") assert result == "" @pytest.mark.asyncio - async def test_delete_success(self, mock_mongo_adapter, mock_zep): + async def test_delete_success(self, mock_mongo_adapter): """Test successful memory deletion.""" - repo = MemoryRepository( - mongo_adapter=mock_mongo_adapter, zep_api_key="test_key" - ) - repo.zep = mock_zep + repo = MemoryRepository(mongo_adapter=mock_mongo_adapter) await repo.delete("user123") mock_mongo_adapter.delete_all.assert_called_once_with( "conversations", {"user_id": "user123"} ) - mock_zep.memory.delete.assert_called_once_with(session_id="user123") - mock_zep.user.delete.assert_called_once_with(user_id="user123") def test_find_success(self, mock_mongo_adapter): """Test successful document find.""" @@ -316,30 +195,6 @@ def test_count_documents_success_no_mongo(self, mock_mongo_adapter): assert result == 0 mock_mongo_adapter.count_documents.assert_not_called() - def test_truncate_text(self): - """Test text truncation.""" - repo = MemoryRepository() - - # Test within limit - assert repo._truncate("Short text") == "Short text" - - # Test at period - assert ( - repo._truncate("First sentence. Second sentence.", 20) == "First sentence." 
- ) - - # Test with ellipsis - result = repo._truncate("a" * 3000) - assert len(result) <= 2503 - assert result.endswith("...") - - # Test empty text - assert repo._truncate("") == "" - - # Test None - with pytest.raises(AttributeError): - repo._truncate(None) - @pytest.mark.asyncio async def test_store_empty_user_id(self): """Test storing with empty user_id.""" @@ -358,24 +213,6 @@ async def test_store_missing_messages_pair(self, mock_mongo_adapter): await repo.store("user123", messages) mock_mongo_adapter.insert_one.assert_not_called() - @pytest.mark.asyncio - async def test_store_zep_session_error(self, mock_zep, valid_messages): - """Test Zep session creation failure.""" - mock_zep.memory.add_session.side_effect = Exception("Session error") - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - await repo.store("user123", valid_messages) - mock_zep.memory.add.assert_called_once() - - @pytest.mark.asyncio - async def test_store_zep_memory_error(self, mock_zep, valid_messages): - """Test Zep memory addition failure.""" - mock_zep.memory.add.side_effect = Exception("Memory error") - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - await repo.store("user123", valid_messages) - mock_zep.memory.add_session.assert_called_once() - @pytest.mark.asyncio async def test_delete_mongo_error(self, mock_mongo_adapter): """Test MongoDB deletion error.""" @@ -384,24 +221,6 @@ async def test_delete_mongo_error(self, mock_mongo_adapter): await repo.delete("user123") mock_mongo_adapter.delete_all.assert_called_once() - @pytest.mark.asyncio - async def test_delete_zep_memory_error(self, mock_zep): - """Test Zep memory deletion error.""" - mock_zep.memory.delete.side_effect = Exception("Memory delete error") - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - await repo.delete("user123") - mock_zep.user.delete.assert_called_once() - - @pytest.mark.asyncio - async def test_delete_zep_user_error(self, mock_zep): - """Test 
Zep user deletion error.""" - mock_zep.user.delete.side_effect = Exception("User delete error") - repo = MemoryRepository(zep_api_key="test_key") - repo.zep = mock_zep - await repo.delete("user123") - mock_zep.memory.delete.assert_called_once() - def test_find_mongo_error(self, mock_mongo_adapter): """Test MongoDB find error.""" mock_mongo_adapter.find.side_effect = Exception("Find error") diff --git a/tests/unit/services/test_knowledge_base.py b/tests/unit/services/test_knowledge_base.py index 9b0549f..9831f12 100644 --- a/tests/unit/services/test_knowledge_base.py +++ b/tests/unit/services/test_knowledge_base.py @@ -74,7 +74,7 @@ def knowledge_base_service( pinecone_adapter=mock_pinecone_adapter, mongodb_adapter=mock_mongodb_adapter, openai_api_key="fake-key", # Required by init - openai_model_name="text-embedding-3-large", # Match dimension logic + openai_model_name="text-embedding-3-small", # Match dimension logic ) return service