diff --git a/README.md b/README.md
index 3f5cddf..762c131 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,10 @@ This tool searches for Confluence pages using a given query string with RAG (Retrieval Augmented Generation)
 - Result Limit: Maximum number of pages to retrieve from Confluence API
 - RAG Settings:
   - embedding_model_save_path: Path to save embedding models
-  - embedding_model_name: Name/path of the embedding model to use
+  - embedding_backend: Which embedding backend to use. Allowed values: "ollama", "sentence_transformers"
+  - embedding_model_name: Name/path of the embedding model to use (used for the sentence-transformers backend)
+  - ollama_host: Base URL of the local Ollama server (used when embedding_backend='ollama')
+  - ollama_model_name: Name of the Ollama embedding model (used when embedding_backend='ollama')
   - cpu_only: Run the tool on CPU only (vs GPU)
   - chunk_size: Maximum size of each content chunk for processing
   - chunk_overlap: Overlap between consecutive chunks
diff --git a/confluence_search.py b/confluence_search.py
index 6add9de..7ca47f1 100644
--- a/confluence_search.py
+++ b/confluence_search.py
@@ -6,7 +6,7 @@
 author_url: https://github.com/RomainNeup
 funding_url: https://github.com/sponsors/RomainNeup
 requirements: markdownify, sentence-transformers, numpy, rank_bm25, scikit-learn
-version: 0.4.0
+version: 0.5.0
 changelog:
 - 0.0.1 - Initial code base.
 - 0.0.2 - Fix Valves variables
@@ -24,6 +24,7 @@
 - 0.2.6 - Add terms splitting option
 - 0.3.0 - Add settings for ssl verification
 - 0.4.0 - Add support for included/exluded confluence spaces in user settings
+- 0.5.0 - Add optional Ollama embedding backend support (local Ollama instance)
 """
 
 import base64
@@ -41,6 +42,11 @@
 from rank_bm25 import BM25Okapi
 from sklearn.neighbors import NearestNeighbors
 
+# New environment variables for Ollama / backend selection
+RAG_EMBEDDING_BACKEND = os.environ.get("RAG_EMBEDDING_BACKEND", "ollama").lower()
+OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://127.0.0.1:11434")  # change the address to match your Ollama instance
+OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "snowflake-arctic-embed2:568m")  # change the model to the one you use
+
 # Get environment variables
 DEFAULT_EMBEDDING_MODEL = os.environ.get(
     "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
 )
@@ -291,12 +297,101 @@ def filter_similar_embeddings(
     return list(sorted(included_idxs))
 
 
+class OllamaEmbeddingClient:
+    """
+    Minimal client for the Ollama embeddings endpoint.
+    Uses POST {host}/api/embeddings with JSON: {"model": "<model-name>", "prompt": "<text>"}.
+    Returns a numpy array of shape (n_texts, dim) for list inputs and a 1D array for a single string.
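+    Example exchange (illustrative; field names follow the public Ollama API, but
+    response shapes can vary between Ollama versions):
+        POST /api/embeddings  {"model": "snowflake-arctic-embed2:568m", "prompt": "hello"}
+        ->  {"embedding": [0.014, -0.271, ...]}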
+    The client parses common response shapes: {"embedding": [...]}, {"embeddings": [...]},
+    and {"data": [{"embedding": [...]}]}.
+    """
+
+    def __init__(self, host: str = OLLAMA_HOST, model: str = OLLAMA_MODEL, timeout: int = 30):
+        self.base_url = host.rstrip("/")
+        # endpoint per Ollama docs/examples
+        self.endpoint = f"{self.base_url}/api/embeddings"
+        self.model = model
+        self.session = requests.Session()
+        self.timeout = timeout
+
+    def _extract_embedding_from_response(self, resp_json: Any) -> List[float]:
+        # Try common response shapes
+        if isinstance(resp_json, dict):
+            # Handles the singular (single vector) response
+            if "embedding" in resp_json and isinstance(resp_json["embedding"], list):
+                return resp_json["embedding"]
+            # Handles the plural (batch) response
+            if "embeddings" in resp_json and isinstance(resp_json["embeddings"], list):
+                # could be a list of lists or a single flat list
+                embeddings = resp_json["embeddings"]
+                if len(embeddings) == 0:
+                    return []
+                # if embeddings is a list of numbers (single embedding), return it
+                if all(isinstance(x, (float, int)) for x in embeddings):
+                    return embeddings
+                # else it is a list of embeddings -> take the first, since a single prompt was sent
+                if isinstance(embeddings[0], list):
+                    return embeddings[0]
+            if "data" in resp_json and isinstance(resp_json["data"], list):
+                first = resp_json["data"][0]
+                if isinstance(first, dict) and "embedding" in first:
+                    return first["embedding"]
+            # fallback: sometimes a key contains the vector directly
+            for v in resp_json.values():
+                if (
+                    isinstance(v, list)
+                    and v
+                    and all(isinstance(x, (float, int)) for x in v)
+                ):
+                    return v
+        # if the response is a top-level list of dicts like [{'embedding': [...]}, ...]
+        if isinstance(resp_json, list) and resp_json and isinstance(resp_json[0], dict):
+            if "embedding" in resp_json[0]:
+                return resp_json[0]["embedding"]
+        raise ConfluenceModelError(
+            "Unexpected Ollama embeddings response format: "
+            + (str(resp_json)[:400] if resp_json is not None else "None")
+        )
+
+    def _embed_single(self, text: str) -> List[float]:
+        payload = {"model": self.model, "prompt": text}
+        try:
+            r = self.session.post(self.endpoint, json=payload, timeout=self.timeout)
+            r.raise_for_status()
+        except requests.RequestException as e:
+            raise ConfluenceModelError(
+                f"Failed to call Ollama embeddings API at {self.endpoint}: {e}"
+            )
+        try:
+            resp_json = r.json()
+        except ValueError:
+            raise ConfluenceModelError(
+                f"Ollama embeddings endpoint returned non-JSON response: {r.text[:400]}"
+            )
+        return self._extract_embedding_from_response(resp_json)
+
+    def encode(self, texts):
+        """
+        Accepts either a single string or an iterable of strings.
+        Returns:
+        - for a single string: 1D numpy array (dim,)
+        - for a list of strings: 2D numpy array (n_texts, dim)
+        """
+        if isinstance(texts, str):
+            emb = self._embed_single(texts)
+            return np.asarray(emb, dtype=float)
+        # assume an iterable of strings; embed one at a time
+        embeddings = []
+        for t in texts:
+            embeddings.append(self._embed_single(t))
+        return np.asarray(embeddings, dtype=float)
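+
+
+# Usage sketch (illustrative only; assumes a reachable local Ollama instance with
+# the embedding model already pulled):
+#
+#   client = OllamaEmbeddingClient(host="http://127.0.0.1:11434",
+#                                  model="snowflake-arctic-embed2:568m")
+#   vec = client.encode("hello world")    # 1D array of shape (dim,)
+#   mat = client.encode(["foo", "bar"])   # 2D array of shape (2, dim)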
+
+
 class DenseRetriever:
     """Semantic search using document embeddings"""
 
     def __init__(
         self,
-        embedding_model: SentenceTransformer,
+        embedding_model: Any,  # SentenceTransformer, OllamaEmbeddingClient, or any object exposing encode(...)
         num_results: int = DEFAULT_TOP_K,
         similarity_threshold: float = DEFAULT_RELEVANCE_THRESHOLD,
         batch_size: int = BATCH_SIZE,
@@ -318,7 +413,11 @@ def add_documents(self, documents: List[Document]):
         for i in range(0, len(documents), self.batch_size):
             batch = documents[i : i + self.batch_size]
             batch_texts = [doc.page_content for doc in batch]
+            # embedding_model.encode should accept a list and return a 2D array
             batch_embeddings = self.embedding_model.encode(batch_texts)
+            # ensure a numpy array regardless of backend
+            if not isinstance(batch_embeddings, np.ndarray):
+                batch_embeddings = np.asarray(batch_embeddings, dtype=float)
             all_embeddings.append(batch_embeddings)
 
         # Concatenate all batches
@@ -327,7 +426,8 @@ def add_documents(self, documents: List[Document]):
         )
 
         # Create KNN index
-        self.knn = NearestNeighbors(n_neighbors=min(self.num_results, len(documents)))
+        n_neighbors = min(self.num_results, len(documents)) if len(documents) > 0 else 1
+        self.knn = NearestNeighbors(n_neighbors=n_neighbors)
         if len(self.document_embeddings) > 0:
             self.knn.fit(self.document_embeddings)
 
@@ -337,6 +437,8 @@ def get_relevant_documents(self, query: str) -> List[Document]:
             return []
 
         query_embedding = self.embedding_model.encode(query)
+        if not isinstance(query_embedding, np.ndarray):
+            query_embedding = np.asarray(query_embedding, dtype=float)
         _, neighbor_indices = self.knn.kneighbors(query_embedding.reshape(1, -1))
         neighbor_indices = neighbor_indices.squeeze(0)
@@ -456,6 +558,9 @@ def __init__(
         device: str = "cpu",
         embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
         batch_size: int = BATCH_SIZE,
+        embedding_backend: str = RAG_EMBEDDING_BACKEND,
+        ollama_host: str = OLLAMA_HOST,
+        ollama_model: str = OLLAMA_MODEL,
     ):
         self.device = device
         self.model_cache_dir = model_cache_dir
@@ -465,14 +570,19 @@ def __init__(
         self.text_splitter = TextSplitter(
             chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP
         )
+        self.embedding_backend = (
+            embedding_backend.lower() if embedding_backend else "sentence_transformers"
+        )
+        self.ollama_host = ollama_host
+        self.ollama_model = ollama_model
 
     async def load_embedding_model(self, event_emitter):
         """Load the embedding model for semantic search"""
         await event_emitter.emit_status(
-            f"Loading embedding model {self.embedding_model_name}...", False
+            f"Loading embedding model (backend={self.embedding_backend})...", False
         )
 
-        def load_model():
+        def load_sentence_transformer():
             return SentenceTransformer(
                 self.embedding_model_name,
                 cache_folder=self.model_cache_dir,
@@ -480,8 +590,32 @@ def load_model():
                 trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
             )
 
-        # Run in an executor to avoid blocking the event loop
-        self.embedding_model = await asyncio.to_thread(load_model)
+        # Choose backend
+        try:
+            if self.embedding_backend in ("ollama", "ollama_api", "ollama_embedding"):
+                # instantiate OllamaEmbeddingClient, wrapped in a function to keep
+                # the same asyncio.to_thread pattern as the other backend
+                def load_ollama():
+                    client = OllamaEmbeddingClient(
+                        host=self.ollama_host, model=self.ollama_model
+                    )
+                    # an optional health-check call could go here; skipped to keep loading lightweight
+                    return client
+
+                self.embedding_model = await asyncio.to_thread(load_ollama)
+            else:
+                # fall back to sentence-transformers
+                self.embedding_model = await asyncio.to_thread(
+                    load_sentence_transformer
+                )
+
+            await event_emitter.emit_status(
+                f"Embedding model loaded (backend={self.embedding_backend}).", False
+            )
+        except Exception as e:
+            raise ConfluenceModelError(
+                f"Failed to load embedding model ({self.embedding_backend}): {e}"
+            )
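+
+        # Both backends are consumed through the same duck-typed encode() interface,
+        # so callers never need to branch on which backend is active.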
 
         return self.embedding_model
 
@@ -584,7 +718,9 @@ def get(self, endpoint: str, params: Dict[str, Any]) -> Dict[str, Any]:
         """Make a GET request to the Confluence API"""
         url = f"{self.base_url}/rest/api/{endpoint}"
         try:
-            response = requests.get(url, params=params, headers=self.headers, verify=self.ssl_verify)
+            response = requests.get(
+                url, params=params, headers=self.headers, verify=self.ssl_verify
+            )
             if response.status_code == 401:
                 raise ConfluenceAuthError(
                     "Authentication failed. Check your credentials."
                 )
@@ -699,7 +835,7 @@ class Valves(BaseModel):
             description="The base URL of your Confluence instance",
         )
         ssl_verify: bool = Field(
-            True, 
+            True,
             description="SSL verification"
         )
         username: str = Field(
@@ -718,11 +854,26 @@ class Valves(BaseModel):
             DEFAULT_MODEL_CACHE_DIR,
             description="Path to the folder in which embedding models will be saved",
         )
+        embedding_backend: str = Field(
+            default=RAG_EMBEDDING_BACKEND,
+            description='Which embedding backend to use. Allowed values: "ollama", "sentence_transformers".',
+        )
         embedding_model_name: str = Field(
             DEFAULT_EMBEDDING_MODEL,
-            description="Name or path of the embedding model to use",
+            description="Name or path of the embedding model to use (used for the sentence-transformers backend).",
+        )
+        ollama_host: str = Field(
+            default=OLLAMA_HOST,
+            description="Base URL of the local Ollama server (used when embedding_backend='ollama')",
+        )
+        ollama_model_name: str = Field(
+            default=OLLAMA_MODEL,
+            description="Name of the Ollama embedding model (used when embedding_backend='ollama')",
+        )
+        cpu_only: bool = Field(
+            default=True,
+            description="Run the tool on CPU only"
         )
-        cpu_only: bool = Field(default=True, description="Run the tool on CPU only")
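+        # Example pairing (illustrative): embedding_backend="ollama" reads ollama_host
+        # and ollama_model_name; embedding_backend="sentence_transformers" reads
+        # embedding_model_name together with cpu_only.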
         chunk_size: int = Field(
             default=DEFAULT_CHUNK_SIZE,
             description="Max. chunk size for Confluence pages",
@@ -837,12 +988,24 @@ async def search_confluence(
             api_username = user_valves.username or self.valves.username
             api_key = user_valves.api_key or self.valves.api_key
             split_terms = user_valves.split_terms
-            included_confluence_spaces = user_valves.included_confluence_spaces.split(",") if user_valves.included_confluence_spaces else None
+            included_confluence_spaces = (
+                user_valves.included_confluence_spaces.split(",")
+                if user_valves.included_confluence_spaces
+                else None
+            )
             if included_confluence_spaces:
-                included_confluence_spaces = [space.strip() for space in included_confluence_spaces]
-            excluded_confluence_spaces = user_valves.excluded_confluence_spaces.split(",") if user_valves.excluded_confluence_spaces else None
+                included_confluence_spaces = [
+                    space.strip() for space in included_confluence_spaces
+                ]
+            excluded_confluence_spaces = (
+                user_valves.excluded_confluence_spaces.split(",")
+                if user_valves.excluded_confluence_spaces
+                else None
+            )
             if excluded_confluence_spaces:
-                excluded_confluence_spaces = [space.strip() for space in excluded_confluence_spaces]
+                excluded_confluence_spaces = [
+                    space.strip() for space in excluded_confluence_spaces
+                ]
         else:
             api_username = self.valves.username
             api_key = self.valves.api_key
@@ -877,22 +1040,32 @@ async def search_confluence(
         # Initialize document retriever and load model with proper error handling
         try:
             if not self.document_retriever:
+                # pass the embedding backend and Ollama settings from the valves
                 self.document_retriever = ConfluenceDocumentRetriever(
                     model_cache_dir=model_cache_dir,
                     device="cpu" if self.valves.cpu_only else "cuda",
                     embedding_model_name=self.valves.embedding_model_name,
                     batch_size=BATCH_SIZE,
+                    embedding_backend=self.valves.embedding_backend,
+                    ollama_host=self.valves.ollama_host,
+                    ollama_model=self.valves.ollama_model_name,
                 )
             if not self.document_retriever.embedding_model:
-                await self.document_retriever.load_embedding_model(event_emitter)
+                try:
+                    await self.document_retriever.load_embedding_model(event_emitter)
+                except ConfluenceModelError as e:
+                    await event_emitter.emit_status(
+                        f"Error loading embedding model: {str(e)}", True, True
+                    )
+                    return f"Error: Failed to load embedding model: {str(e)}"
         except Exception as e:
             await event_emitter.emit_status(
-                f"Error loading embedding model: {str(e)}", True, True
+                f"Error initializing document retriever: {str(e)}", True, True
             )
-            return f"Error: Failed to load embedding model: {str(e)}"
+            return f"Error: Failed to initialize document retriever: {str(e)}"
 
-        # Create Confluence client with proper error handling
+        # Create Confluence client with error handling
         try:
             confluence = Confluence(
                 username=api_username,