Skip to content

Commit be3ed2a

Browse files
working
1 parent b98850f commit be3ed2a

File tree

8 files changed

+710
-374
lines changed

8 files changed

+710
-374
lines changed

poetry.lock

Lines changed: 362 additions & 362 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "solana-agent"
3-
version = "29.1.3"
3+
version = "29.2.0-dev80"
44
description = "AI Agents for Solana"
55
authors = ["Bevan Hunt <bevan@bevanhunt.com>"]
66
license = "MIT"
@@ -24,18 +24,18 @@ python_paths = [".", "tests"]
2424

2525
[tool.poetry.dependencies]
2626
python = ">=3.12,<4.0"
27-
openai = "1.78.1"
27+
openai = "1.79.0"
2828
pydantic = ">=2"
29-
pymongo = "4.12.1"
30-
zep-cloud = "2.12.1"
31-
instructor = "1.8.1"
29+
pymongo = "4.13.0"
30+
zep-cloud = "2.12.3"
31+
instructor = "1.8.2"
3232
pinecone = "6.0.2"
33-
llama-index-core = "0.12.35"
33+
llama-index-core = "0.12.37"
3434
llama-index-embeddings-openai = "0.3.1"
3535
pypdf = "5.5.0"
3636
scrubadub = "2.0.1"
37-
logfire = "3.15.1"
38-
typer = "0.15.3"
37+
logfire = "3.16.0"
38+
typer = "0.15.4"
3939
rich = ">=13,<14.0"
4040
pillow = "11.2.1"
4141

@@ -50,7 +50,7 @@ sphinx-rtd-theme = "^3.0.2"
5050
myst-parser = "^4.0.1"
5151
sphinx-autobuild = "^2024.10.3"
5252
mongomock = "^4.3.0"
53-
ruff = "^0.11.9"
53+
ruff = "^0.11.10"
5454

5555
[tool.poetry.scripts]
5656
solana-agent = "solana_agent.cli:app"
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import uuid
2+
from typing import Dict, Any, List, Optional
3+
from solana_agent.interfaces.providers.graph_storage import GraphStorageProvider
4+
from solana_agent.adapters.mongodb_adapter import MongoDBAdapter
5+
6+
7+
class MongoDBGraphAdapter(GraphStorageProvider):
8+
def __init__(
9+
self,
10+
mongo_adapter: MongoDBAdapter,
11+
node_collection: str = "graph_nodes",
12+
edge_collection: str = "graph_edges",
13+
):
14+
self.mongo = mongo_adapter
15+
self.node_collection = node_collection
16+
self.edge_collection = edge_collection
17+
18+
async def add_node(self, node: Dict[str, Any]) -> str:
19+
node = dict(node)
20+
node["uuid"] = node.get("uuid", str(uuid.uuid4()))
21+
self.mongo.insert_one(self.node_collection, node)
22+
return node["uuid"]
23+
24+
async def add_edge(self, edge: Dict[str, Any]) -> str:
25+
edge = dict(edge)
26+
edge["uuid"] = edge.get("uuid", str(uuid.uuid4()))
27+
return self.mongo.insert_one(self.edge_collection, edge)
28+
29+
async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
30+
return self.mongo.find_one(self.node_collection, {"uuid": node_id})
31+
32+
async def get_edges(
33+
self, node_id: str, direction: str = "both"
34+
) -> List[Dict[str, Any]]:
35+
if direction == "out":
36+
query = {"source": node_id}
37+
elif direction == "in":
38+
query = {"target": node_id}
39+
else:
40+
query = {"$or": [{"source": node_id}, {"target": node_id}]}
41+
return self.mongo.find(self.edge_collection, query)
42+
43+
async def find_neighbors(
44+
self, node_id: str, depth: int = 1
45+
) -> List[Dict[str, Any]]:
46+
neighbors = set()
47+
current = {node_id}
48+
for _ in range(depth):
49+
edges = await self.get_edges(list(current)[0])
50+
for edge in edges:
51+
neighbors.add(edge.get("source"))
52+
neighbors.add(edge.get("target"))
53+
current = neighbors
54+
neighbors.discard(node_id)
55+
return [await self.get_node(nid) for nid in neighbors if nid]
56+
57+
async def temporal_query(
58+
self, node_id: str, start_time: Optional[str], end_time: Optional[str]
59+
) -> List[Dict[str, Any]]:
60+
query = {"$or": [{"source": node_id}, {"target": node_id}]}
61+
if start_time or end_time:
62+
query["timestamp"] = {}
63+
if start_time:
64+
query["timestamp"]["$gte"] = start_time
65+
if end_time:
66+
query["timestamp"]["$lte"] = end_time
67+
return self.mongo.find(self.edge_collection, query)

solana_agent/factories/agent_factory.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from typing import Dict, Any, List
1111

1212
# Service imports
13+
from solana_agent.adapters.mongodb_graph_adapter import MongoDBGraphAdapter
1314
from solana_agent.adapters.pinecone_adapter import PineconeAdapter
1415
from solana_agent.interfaces.guardrails.guardrails import (
1516
InputGuardrail,
1617
OutputGuardrail,
1718
)
19+
from solana_agent.services.graph_memory import GraphMemoryService
1820
from solana_agent.services.query import QueryService
1921
from solana_agent.services.agent import AgentService
2022
from solana_agent.services.routing import RoutingService
@@ -241,6 +243,7 @@ def create_from_config(config: Dict[str, Any]) -> QueryService:
241243

242244
# Initialize Knowledge Base if configured
243245
knowledge_base = None
246+
graph_memory_service = None
244247
kb_config = config.get("knowledge_base")
245248
# Requires both KB config section and MongoDB adapter
246249
if kb_config and db_adapter:
@@ -311,6 +314,34 @@ def create_from_config(config: Dict[str, Any]) -> QueryService:
311314
"Knowledge Base Service initialized successfully."
312315
) # Use logger.info
313316

317+
# Create Graph Memory Service
318+
graph_memory_config = pinecone_config.get("agent_memory", {})
319+
try:
320+
# Create MongoDBGraphAdapter
321+
mongo_graph_adapter = MongoDBGraphAdapter(
322+
mongo_adapter=db_adapter,
323+
node_collection=graph_memory_config.get(
324+
"node_collection", "graph_nodes"
325+
),
326+
edge_collection=graph_memory_config.get(
327+
"edge_collection", "graph_edges"
328+
),
329+
)
330+
331+
# Create GraphMemoryService
332+
graph_memory_service = GraphMemoryService(
333+
graph_adapter=mongo_graph_adapter,
334+
pinecone_adapter=pinecone_adapter,
335+
openai_adapter=llm_adapter,
336+
embedding_model=graph_memory_config.get(
337+
"embedding_model", "text-embedding-3-large"
338+
),
339+
)
340+
logger.info("Graph Memory Service initialized successfully.")
341+
except Exception as e:
342+
logger.exception(f"Failed to initialize Graph Memory: {e}")
343+
graph_memory_service = None
344+
314345
except Exception as e:
315346
# Use logger.exception to include traceback automatically
316347
logger.exception(f"Failed to initialize Knowledge Base: {e}")
@@ -324,6 +355,7 @@ def create_from_config(config: Dict[str, Any]) -> QueryService:
324355
knowledge_base=knowledge_base, # Pass the potentially created KB
325356
kb_results_count=kb_config.get("results_count", 3) if kb_config else 3,
326357
input_guardrails=input_guardrails,
358+
graph_memory=graph_memory_service,
327359
)
328360

329361
return query_service
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Dict, Any, List, Optional
3+
4+
5+
class GraphStorageProvider(ABC):
6+
@abstractmethod
7+
async def add_node(self, node: Dict[str, Any]) -> str:
8+
pass
9+
10+
@abstractmethod
11+
async def add_edge(self, edge: Dict[str, Any]) -> str:
12+
pass
13+
14+
@abstractmethod
15+
async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
16+
pass
17+
18+
@abstractmethod
19+
async def get_edges(
20+
self, node_id: str, direction: str = "both"
21+
) -> List[Dict[str, Any]]:
22+
pass
23+
24+
@abstractmethod
25+
async def find_neighbors(
26+
self, node_id: str, depth: int = 1
27+
) -> List[Dict[str, Any]]:
28+
pass
29+
30+
@abstractmethod
31+
async def temporal_query(
32+
self, node_id: str, start_time: Optional[str], end_time: Optional[str]
33+
) -> List[Dict[str, Any]]:
34+
pass
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Dict, List, Any
3+
4+
5+
class GraphMemoryService(ABC):
6+
"""
7+
Interface for a graph memory service.
8+
"""
9+
10+
@abstractmethod
11+
async def add_episode(
12+
self,
13+
user_message: str,
14+
assistant_message: str,
15+
user_id: str,
16+
):
17+
"""
18+
Add an episode to the graph memory.
19+
"""
20+
pass
21+
22+
@abstractmethod
23+
async def search(
24+
self, query: str, user_id: str, top_k: int = 5
25+
) -> List[Dict[str, Any]]:
26+
"""
27+
Search the graph memory for relevant episodes.
28+
"""
29+
pass
30+
31+
@abstractmethod
32+
async def traverse(self, node_id: str, depth: int = 1) -> List[Dict[str, Any]]:
33+
"""
34+
Traverse the graph memory from a given node ID.
35+
"""
36+
pass

solana_agent/services/graph_memory.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import datetime
2+
import uuid
3+
from typing import Dict, Any, List
4+
from solana_agent.adapters.openai_adapter import OpenAIAdapter
5+
from solana_agent.adapters.pinecone_adapter import PineconeAdapter
6+
from solana_agent.adapters.mongodb_graph_adapter import MongoDBGraphAdapter
7+
from solana_agent.interfaces.services.graph_memory import (
8+
GraphMemoryService as GraphMemoryServiceInterface,
9+
)
10+
11+
12+
class GraphMemoryService(GraphMemoryServiceInterface):
13+
def __init__(
14+
self,
15+
graph_adapter: MongoDBGraphAdapter,
16+
pinecone_adapter: PineconeAdapter,
17+
openai_adapter: OpenAIAdapter,
18+
embedding_model: str = "text-embedding-3-large",
19+
):
20+
self.graph = graph_adapter
21+
self.pinecone = pinecone_adapter
22+
self.openai = openai_adapter
23+
self.embedding_model = embedding_model
24+
25+
async def add_episode(
26+
self,
27+
user_message: str,
28+
assistant_message: str,
29+
user_id: str,
30+
):
31+
entities = [
32+
{"type": "user_message", "text": user_message, "user_id": user_id},
33+
{
34+
"type": "assistant_message",
35+
"text": assistant_message,
36+
"user_id": user_id,
37+
},
38+
]
39+
episode_id = str(uuid.uuid4())
40+
node_ids = []
41+
for entity in entities:
42+
entity["episode_id"] = episode_id
43+
node_id = await self.graph.add_node(entity)
44+
node_ids.append(node_id)
45+
edge = {
46+
"source": node_ids[0],
47+
"target": node_ids[1],
48+
"type": "reply",
49+
"episode_id": episode_id,
50+
"user_id": user_id,
51+
"timestamp": datetime.datetime.now(datetime.timezone.utc),
52+
}
53+
await self.graph.add_edge(edge)
54+
# Save vectors in user-specific namespace
55+
namespace = f"{user_id}_memory"
56+
for node_id, entity in zip(node_ids, entities):
57+
embedding = await self.openai.embed_text(
58+
entity["text"], model=self.embedding_model
59+
)
60+
await self.pinecone.upsert(
61+
[{"id": node_id, "values": embedding}],
62+
namespace=namespace,
63+
)
64+
return episode_id
65+
66+
async def search(
67+
self, query: str, user_id: str, top_k: int = 5
68+
) -> List[Dict[str, Any]]:
69+
embedding = await self.openai.embed_text(query, model=self.embedding_model)
70+
namespace = f"{user_id}_memory"
71+
results = await self.pinecone.query_and_rerank(
72+
vector=embedding,
73+
query_text_for_rerank=query,
74+
top_k=top_k,
75+
namespace=namespace,
76+
)
77+
node_ids = [r["id"] for r in results]
78+
# Only return nodes that match user_id
79+
nodes = []
80+
for nid in node_ids:
81+
print(f"Node ID: {nid}")
82+
node = await self.graph.get_node(nid)
83+
if node and node.get("user_id") == user_id:
84+
nodes.append(node)
85+
return nodes
86+
87+
async def traverse(self, node_id: str, depth: int = 1) -> List[Dict[str, Any]]:
88+
return await self.graph.find_neighbors(node_id, depth=depth)

0 commit comments

Comments
 (0)