5 changes: 4 additions & 1 deletion README.md
@@ -37,7 +37,10 @@ This tool searches for Confluence pages using a given query string with RAG (Retrieval-Augmented Generation)
- Result Limit: Maximum number of pages to retrieve from Confluence API
- RAG Settings:
- embedding_model_save_path: Path to save embedding models
- embedding_backend: Which embedding backend to use. Allowed values: "ollama", "sentence_transformers" (see the configuration sketch below)
- embedding_model_name: Name/path of the embedding model to use (used for sentence-transformers backend)
- ollama_host: Base URL of local Ollama server (used when embedding_backend='ollama')
- ollama_model_name: Name of the Ollama embedding model (used when embedding_backend='ollama')
- cpu_only: Run the tool on CPU only (vs GPU)
- chunk_size: Maximum size of each content chunk for processing
- chunk_overlap: Overlap between consecutive chunks
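
As a hedged configuration sketch (the variable names are the ones read at import time in confluence_search.py; the host and model values below are placeholders for your own setup), the Ollama backend can be selected via environment variables:

```python
# Sketch only: these environment variables are read once when the tool
# loads (RAG_EMBEDDING_BACKEND, OLLAMA_HOST, OLLAMA_MODEL). The host and
# model values are placeholders; adjust them to your own Ollama setup.
import os

os.environ["RAG_EMBEDDING_BACKEND"] = "ollama"
os.environ["OLLAMA_HOST"] = "http://127.0.0.1:11434"
os.environ["OLLAMA_MODEL"] = "snowflake-arctic-embed2:568m"
```

The same settings can also be overridden per user through the valves listed above.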
211 changes: 192 additions & 19 deletions confluence_search.py
@@ -6,7 +6,7 @@
author_url: https://github.com/RomainNeup
funding_url: https://github.com/sponsors/RomainNeup
requirements: markdownify, sentence-transformers, numpy, rank_bm25, scikit-learn
version: 0.5.0
changelog:
- 0.0.1 - Initial code base.
- 0.0.2 - Fix Valves variables
@@ -24,6 +24,7 @@
- 0.2.6 - Add terms splitting option
- 0.3.0 - Add settings for ssl verification
- 0.4.0 - Add support for included/excluded confluence spaces in user settings
- 0.5.0 - Add optional Ollama embedding backend support (local Ollama instance)
"""

import base64
@@ -41,6 +42,11 @@
from rank_bm25 import BM25Okapi
from sklearn.neighbors import NearestNeighbors

# New environment variables for Ollama / backend selection
RAG_EMBEDDING_BACKEND = os.environ.get("RAG_EMBEDDING_BACKEND", "ollama").lower()
OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://127.0.0.1:11434")  # change the address to match your Ollama instance
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "snowflake-arctic-embed2:568m")  # change the model to the one you use

# Get environment variables
DEFAULT_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
@@ -291,12 +297,101 @@ def filter_similar_embeddings(
return list(sorted(included_idxs))


class OllamaEmbeddingClient:
"""
Minimal client for Ollama embedding endpoint.
Uses POST {host}/api/embeddings with JSON: {"model": "<model>", "prompt": "<text>"}
It returns a numpy array of shape (n_texts, dim) for list inputs, and 1D array for single string.
The client parses common response shapes: {"embedding": [...]}, {"embeddings": [...]}, {"data": [{"embedding": [...]}]}
"""

def __init__(self, host: str = OLLAMA_HOST, model: str = OLLAMA_MODEL, timeout: int = 30):
self.base_url = host.rstrip("/")
# endpoint per Ollama docs/examples
self.endpoint = f"{self.base_url}/api/embeddings"
self.model = model
self.session = requests.Session()
self.timeout = timeout

def _extract_embedding_from_response(self, resp_json: Any):
# Try common response shapes
if isinstance(resp_json, dict):
# Handles singular (single vector) response
if "embedding" in resp_json and isinstance(resp_json["embedding"], list):
return resp_json["embedding"]
# Handles plural (batch) response
if "embeddings" in resp_json and isinstance(resp_json["embeddings"], list):
# could be list of lists or a single list
embeddings = resp_json["embeddings"]
if len(embeddings) == 0:
return []
# If embeddings is a flat list of numbers, it is a single embedding
if all(isinstance(x, (float, int)) for x in embeddings):
return embeddings
# Otherwise a list of embeddings; one prompt is sent per request, so take the first
if isinstance(embeddings[0], list):
return embeddings[0]
if "data" in resp_json and isinstance(resp_json["data"], list):
first = resp_json["data"][0]
if isinstance(first, dict) and "embedding" in first:
return first["embedding"]
# fallback: sometimes a key contains the vector directly
for v in resp_json.values():
if (
isinstance(v, list)
and v
and all(isinstance(x, (float, int)) for x in v)
):
return v
# If response is top-level list of dicts like [{'embedding': [...]}, ...]
if isinstance(resp_json, list) and resp_json and isinstance(resp_json[0], dict):
if "embedding" in resp_json[0]:
return resp_json[0]["embedding"]
raise ConfluenceModelError(
"Unexpected Ollama embeddings response format: "
+ (str(resp_json)[:400] if resp_json is not None else "None")
)

def _embed_single(self, text: str) -> List[float]:
payload = {"model": self.model, "prompt": text}
try:
r = self.session.post(self.endpoint, json=payload, timeout=self.timeout)
r.raise_for_status()
except requests.RequestException as e:
raise ConfluenceModelError(
f"Failed to call Ollama embeddings API at {self.endpoint}: {e}"
)
try:
resp_json = r.json()
except ValueError:
raise ConfluenceModelError(
f"Ollama embeddings endpoint returned non-JSON response: {r.text[:400]}"
)
return self._extract_embedding_from_response(resp_json)

def encode(self, texts):
"""
Accepts either a single string or an iterable of strings.
Returns:
- for single string: 1D numpy array (dim,)
- for list of strings: 2D numpy array (n_texts, dim)
"""
if isinstance(texts, str):
emb = self._embed_single(texts)
return np.asarray(emb, dtype=float)
# assume iterable
embeddings = []
for t in texts:
embeddings.append(self._embed_single(t))
return np.asarray(embeddings, dtype=float)
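
A minimal usage sketch of the client above, assuming a local Ollama server is running and the embedding model has already been pulled (e.g. via `ollama pull`):

```python
# Assumes the OllamaEmbeddingClient class defined above and a running
# local Ollama server with the embedding model already pulled.
client = OllamaEmbeddingClient(
    host="http://127.0.0.1:11434",
    model="snowflake-arctic-embed2:568m",
)

query_vec = client.encode("How do we deploy to production?")  # 1D array, shape (dim,)
doc_vecs = client.encode(["first page text", "second page text"])  # 2D array, shape (2, dim)
print(query_vec.shape, doc_vecs.shape)
```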


class DenseRetriever:
"""Semantic search using document embeddings"""

def __init__(
self,
embedding_model: Any,  # SentenceTransformer, OllamaEmbeddingClient, or any object exposing encode(...)
num_results: int = DEFAULT_TOP_K,
similarity_threshold: float = DEFAULT_RELEVANCE_THRESHOLD,
batch_size: int = BATCH_SIZE,
@@ -318,7 +413,11 @@ def add_documents(self, documents: List[Document]):
for i in range(0, len(documents), self.batch_size):
batch = documents[i : i + self.batch_size]
batch_texts = [doc.page_content for doc in batch]
# any backend's encode() must accept a list and return a 2D array
batch_embeddings = self.embedding_model.encode(batch_texts)
# ensure numpy array
if not isinstance(batch_embeddings, np.ndarray):
batch_embeddings = np.asarray(batch_embeddings, dtype=float)
all_embeddings.append(batch_embeddings)

# Concatenate all batches
@@ -327,7 +426,8 @@ def add_documents(self, documents: List[Document]):
)

# Create KNN index
n_neighbors = min(self.num_results, len(documents)) if len(documents) > 0 else 1
self.knn = NearestNeighbors(n_neighbors=n_neighbors)
if len(self.document_embeddings) > 0:
self.knn.fit(self.document_embeddings)

@@ -337,6 +437,8 @@ def get_relevant_documents(self, query: str) -> List[Document]:
return []

query_embedding = self.embedding_model.encode(query)
if not isinstance(query_embedding, np.ndarray):
query_embedding = np.asarray(query_embedding, dtype=float)

_, neighbor_indices = self.knn.kneighbors(query_embedding.reshape(1, -1))
neighbor_indices = neighbor_indices.squeeze(0)
@@ -456,6 +558,9 @@ def __init__(
device: str = "cpu",
embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
batch_size: int = BATCH_SIZE,
embedding_backend: str = RAG_EMBEDDING_BACKEND,
ollama_host: str = OLLAMA_HOST,
ollama_model: str = OLLAMA_MODEL,
):
self.device = device
self.model_cache_dir = model_cache_dir
@@ -465,23 +570,52 @@ def __init__(
self.text_splitter = TextSplitter(
chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP
)
self.embedding_backend = (
embedding_backend.lower() if embedding_backend else "sentence_transformers"
)
self.ollama_host = ollama_host
self.ollama_model = ollama_model

async def load_embedding_model(self, event_emitter):
"""Load the embedding model for semantic search"""
await event_emitter.emit_status(
f"Loading embedding model {self.embedding_model_name}...", False
f"Loading embedding model (backend={self.embedding_backend})...", False
)

def load_sentence_transformer():
return SentenceTransformer(
self.embedding_model_name,
cache_folder=self.model_cache_dir,
device=self.device,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)

# Choose backend
try:
if self.embedding_backend in ("ollama", "ollama_api", "ollama_embedding"):
# Instantiate OllamaEmbeddingClient inside a function so it can be
# run through asyncio.to_thread, matching the sentence-transformers path
def load_ollama():
client = OllamaEmbeddingClient(
host=self.ollama_host, model=self.ollama_model
)
# an optional health-check call is skipped here to keep startup lightweight
return client

self.embedding_model = await asyncio.to_thread(load_ollama)
else:
# fall back to sentence-transformers
self.embedding_model = await asyncio.to_thread(
load_sentence_transformer
)

await event_emitter.emit_status(
f"Embedding model loaded (backend={self.embedding_backend}).", False
)
except Exception as e:
raise ConfluenceModelError(
f"Failed to load embedding model ({self.embedding_backend}): {e}"
)

return self.embedding_model
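
For illustration, a hedged sketch of wiring the new constructor parameters together; the cache path is a hypothetical value, and the parameter names follow this diff:

```python
# Sketch only: builds the retriever with the Ollama backend using the
# parameters added in this PR; model_cache_dir is a placeholder path.
retriever = ConfluenceDocumentRetriever(
    model_cache_dir="/tmp/embedding-models",
    device="cpu",
    embedding_backend="ollama",
    ollama_host="http://127.0.0.1:11434",
    ollama_model="snowflake-arctic-embed2:568m",
)
```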

@@ -584,7 +718,9 @@ def get(self, endpoint: str, params: Dict[str, Any]) -> Dict[str, Any]:
"""Make a GET request to the Confluence API"""
url = f"{self.base_url}/rest/api/{endpoint}"
try:
response = requests.get(
url, params=params, headers=self.headers, verify=self.ssl_verify
)
if response.status_code == 401:
raise ConfluenceAuthError(
"Authentication failed. Check your credentials."
@@ -699,7 +835,7 @@ class Valves(BaseModel):
description="The base URL of your Confluence instance",
)
ssl_verify: bool = Field(
True,
description="SSL verification"
)
username: str = Field(
@@ -718,11 +854,26 @@ class Valves(BaseModel):
DEFAULT_MODEL_CACHE_DIR,
description="Path to the folder in which embedding models will be saved",
)
embedding_backend: str = Field(
default=RAG_EMBEDDING_BACKEND,
description='Which embedding backend to use. Allowed values: "ollama", "sentence_transformers".',
)
embedding_model_name: str = Field(
DEFAULT_EMBEDDING_MODEL,
description="Name or path of the embedding model to use",
description="Name or path of the embedding model to use (used for sentence-transformers backend).",
)
ollama_host: str = Field(
default=OLLAMA_HOST,
description="Base URL of local Ollama server (used when embedding_backend='ollama')",
)
ollama_model_name: str = Field(
default=OLLAMA_MODEL,
description="Name of the Ollama embedding model (used when embedding_backend='ollama')",
)
cpu_only: bool = Field(
default=True,
description="Run the tool on CPU only"
)
chunk_size: int = Field(
default=DEFAULT_CHUNK_SIZE,
description="Max. chunk size for Confluence pages",
@@ -837,12 +988,24 @@ async def search_confluence(
api_username = user_valves.username or self.valves.username
api_key = user_valves.api_key or self.valves.api_key
split_terms = user_valves.split_terms
included_confluence_spaces = (
user_valves.included_confluence_spaces.split(",")
if user_valves.included_confluence_spaces
else None
)
if included_confluence_spaces:
included_confluence_spaces = [
space.strip() for space in included_confluence_spaces
]
excluded_confluence_spaces = (
user_valves.excluded_confluence_spaces.split(",")
if user_valves.excluded_confluence_spaces
else None
)
if excluded_confluence_spaces:
excluded_confluence_spaces = [
space.strip() for space in excluded_confluence_spaces
]
else:
api_username = self.valves.username
api_key = self.valves.api_key
@@ -877,22 +1040,32 @@ async def search_confluence(
# Initialize document retriever and load model with proper error handling
try:
if not self.document_retriever:
# pass the embedding backend and Ollama settings from the valves
self.document_retriever = ConfluenceDocumentRetriever(
model_cache_dir=model_cache_dir,
device="cpu" if self.valves.cpu_only else "cuda",
embedding_model_name=self.valves.embedding_model_name,
batch_size=BATCH_SIZE,
embedding_backend=self.valves.embedding_backend,
ollama_host=self.valves.ollama_host,
ollama_model=self.valves.ollama_model_name,
)

if not self.document_retriever.embedding_model:
try:
await self.document_retriever.load_embedding_model(event_emitter)
except ConfluenceModelError as e:
await event_emitter.emit_status(
f"Error loading embedding model: {str(e)}", True, True
)
return f"Error: Failed to load embedding model: {str(e)}"
except Exception as e:
await event_emitter.emit_status(
f"Error initializing document retriever: {str(e)}", True, True
)
return f"Error: Failed to initialize document retriever: {str(e)}"

# Create Confluence client with error handling
try:
confluence = Confluence(
username=api_username,
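
For reference, the request/response contract that `OllamaEmbeddingClient` relies on can be exercised directly; a minimal sketch, assuming a local Ollama server and a pulled embedding model:

```python
# Direct call to the /api/embeddings endpoint used by the client above.
# The {"model", "prompt"} request and {"embedding": [...]} response follow
# the contract described in the OllamaEmbeddingClient docstring.
import requests

resp = requests.post(
    "http://127.0.0.1:11434/api/embeddings",
    json={"model": "snowflake-arctic-embed2:568m", "prompt": "hello world"},
    timeout=30,
)
resp.raise_for_status()
vector = resp.json()["embedding"]  # list of floats
print(f"embedding dimension: {len(vector)}")
```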