diff --git a/.gitignore b/.gitignore
index 4e9fd655..9fa693f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -68,4 +68,3 @@ spring_ai/drop.sql
src/client/spring_ai/target/classes/*
api_server_key
.env
-
diff --git a/src/client/content/chatbot.py.mcp b/src/client/content/chatbot.py.mcp
new file mode 100644
index 00000000..c79146f7
--- /dev/null
+++ b/src/client/content/chatbot.py.mcp
@@ -0,0 +1,300 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+This file merges the Streamlit Chatbot GUI with the MCPClient for a complete,
+runnable example demonstrating their integration.
+"""
+
+# spell-checker:ignore streamlit, oraclevs, selectai, langgraph, prebuilt
+import asyncio
+import inspect
+import json
+import base64
+import copy
+
+import streamlit as st
+from streamlit import session_state as state
+
+from client.content.config.tabs.models import get_models
+
+import client.utils.st_common as st_common
+import client.utils.api_call as api_call
+
+from client.utils.st_footer import render_chat_footer
+import common.logging_config as logging_config
+from client.mcp.client import MCPClient
+
+logger = logging_config.logging.getLogger("client.content.chatbot")
+
+
+#############################################################################
+# Functions
+#############################################################################
+def show_vector_search_refs(context):
+ """When Vector Search Content Found, show the references"""
+ st.markdown("**References:**")
+ ref_src = set()
+ ref_cols = st.columns([3, 3, 3])
+ # Create a button in each column
+ for i, (ref_col, chunk) in enumerate(zip(ref_cols, context[0])):
+ with ref_col.popover(f"Reference: {i + 1}"):
+ logger.debug("Chunk Content: %s", chunk)
+ st.subheader("Reference Text", divider="red")
+ st.markdown(chunk["page_content"])
+ try:
+ ref_src.add(chunk["metadata"]["filename"])
+ st.subheader("Metadata", divider="red")
+ st.markdown(f"File: {chunk['metadata']['source']}")
+ st.markdown(f"Chunk: {chunk['metadata']['page']}")
+ except KeyError:
+ logger.error("Chunk Metadata NOT FOUND!!")
+
+ for link in ref_src:
+ st.markdown("- " + link)
+ st.markdown(f"**Notes:** Vector Search Query - {context[1]}")
+
+
+#############################################################################
+# MAIN
+#############################################################################
+async def main() -> None:
+ """Streamlit GUI"""
+ try:
+ get_models()
+ except api_call.ApiError:
+ st.stop()
+ #########################################################################
+ # Sidebar Settings
+ #########################################################################
+ ll_models_enabled = st_common.enabled_models_lookup("ll")
+ if not ll_models_enabled:
+ st.error("No language models are configured and/or enabled. Disabling Client.", icon="🛑")
+ st.stop()
+ state.enable_client = True
+ st_common.tools_sidebar()
+ st_common.history_sidebar()
+ st_common.ll_sidebar()
+ st_common.selectai_sidebar()
+ st_common.vector_search_sidebar()
+ if not state.enable_client:
+ st.stop()
+
+ #########################################################################
+ # Chatty-Bot Centre
+ #########################################################################
+
+ if "messages" not in state:
+ state.messages = []
+
+ st.chat_message("ai").write("Hello, how can I help you?")
+
+ for message in state.messages:
+ role = message.get("role")
+ display_role = ""
+ if role in ("human", "user"):
+ display_role = "human"
+ elif role in ("ai", "assistant"):
+ if not message.get("content") and not message.get("tool_trace"):
+ continue
+ display_role = "assistant"
+ else:
+ continue
+
+ with st.chat_message(display_role):
+ if "tool_trace" in message and message["tool_trace"]:
+ for tool_call in message["tool_trace"]:
+ with st.expander(f"🛠️ **Tool Call:** `{tool_call['name']}`", expanded=False):
+ st.text("Arguments:")
+ st.code(json.dumps(tool_call.get("args", {}), indent=2), language="json")
+ if "error" in tool_call:
+ st.text("Error:")
+ st.error(tool_call["error"])
+ else:
+ st.text("Result:")
+ st.code(tool_call.get("result", ""), language="json")
+ if message.get("content"):
+ # Display file attachments if present
+ if "attachments" in message and message["attachments"]:
+ for file in message["attachments"]:
+ # Show appropriate icon based on file type
+ if file["type"].startswith("image/"):
+ st.image(file["preview"], use_container_width=True)
+ st.markdown(f"🖼️ **{file['name']}** ({file['size'] // 1024} KB)")
+ elif file["type"] == "application/pdf":
+ st.markdown(f"📄 **{file['name']}** ({file['size'] // 1024} KB)")
+ elif file["type"] in ("text/plain", "text/markdown"):
+ st.markdown(f"📝 **{file['name']}** ({file['size'] // 1024} KB)")
+ else:
+ st.markdown(f"📎 **{file['name']}** ({file['size'] // 1024} KB)")
+
+ # Display message content - handle both string and list formats
+ content = message.get("content")
+ if isinstance(content, list):
+ # Extract and display only text parts
+ text_parts = [part["text"] for part in content if part["type"] == "text"]
+ st.markdown("\n".join(text_parts))
+ else:
+ st.markdown(content)
+
+ sys_prompt = state.client_settings["prompts"]["sys"]
+ render_chat_footer()
+
+ if human_request := st.chat_input(
+ f"Ask your question here... (current prompt: {sys_prompt})",
+ accept_file=True,
+ file_type=["jpg", "jpeg", "png", "pdf", "txt", "docx"],
+ key=f"chat_input_{len(state.messages)}",
+ ):
+ # Process message with potential file attachments
+ message = {"role": "user", "content": human_request.text}
+
+ # Handle file attachments
+ if hasattr(human_request, "files") and human_request.files:
+ # Store file information separately from content
+ message["attachments"] = []
+ for file in human_request.files:
+ file_bytes = file.read()
+ file_b64 = base64.b64encode(file_bytes).decode("utf-8")
+ message["attachments"].append(
+ {
+ "name": file.name,
+ "type": file.type,
+ "size": len(file_bytes),
+ "data": file_b64,
+ "preview": f"data:{file.type};base64,{file_b64}" if file.type.startswith("image/") else None,
+ }
+ )
+
+ state.messages.append(message)
+ st.rerun()
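+
+ # When the last message is from the user, generate the assistant's reply.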
+ if state.messages and state.messages[-1]["role"] == "user":
+ try:
+ with st.chat_message("ai"):
+ with st.spinner("Thinking..."):
+ # Deep copy so the injected api_key never leaks back into session state
+ client_settings_for_request = copy.deepcopy(state.client_settings)
+ model_id = client_settings_for_request.get("ll_model", {}).get("model")
+ if model_id:
+ all_model_configs = st_common.enabled_models_lookup("ll")
+ model_config = all_model_configs.get(model_id, {})
+ if "api_key" in model_config:
+ if "ll_model" not in client_settings_for_request:
+ client_settings_for_request["ll_model"] = {}
+ client_settings_for_request["ll_model"]["api_key"] = model_config["api_key"]
+
+ # Prepare message history for backend
+ message_history = []
+ for msg in state.messages:
+ # Create a copy of the message
+ processed_msg = msg.copy()
+
+ # If there are attachments, include them in the content
+ if "attachments" in msg and msg["attachments"]:
+ # Start with the text content
+ text_content = msg["content"]
+
+ # Handle list content format (from OpenAI API)
+ if isinstance(text_content, list):
+ text_parts = [part["text"] for part in text_content if part["type"] == "text"]
+ text_content = "\n".join(text_parts)
+
+ # Create a list to hold structured content parts
+ content_list = [{"type": "text", "text": text_content}]
+
+ non_image_references = []
+ for attachment in msg["attachments"]:
+ if attachment["type"].startswith("image/"):
+ # Only add image URLs for user messages
+ if msg["role"] in ("human", "user"):
+ # Normalize image MIME types for compatibility
+ mime_type = attachment["type"]
+ if mime_type == "image/jpg":
+ mime_type = "image/jpeg"
+
+ content_list.append(
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:{mime_type};base64,{attachment['data']}",
+ "detail": "low",
+ },
+ }
+ )
+ else:
+ # Handle non-image files as text references
+ non_image_references.append(
+ f"\n[File: {attachment['name']} ({attachment['size'] // 1024} KB)]"
+ )
+
+ # If there were non-image files, append their references to the main text part
+ if non_image_references:
+ content_list[0]["text"] += "".join(non_image_references)
+
+ processed_msg["content"] = content_list
+ # Convert list content to string format
+ elif isinstance(msg.get("content"), list):
+ text_parts = [part["text"] for part in msg["content"] if part["type"] == "text"]
+ processed_msg["content"] = str("\n".join(text_parts))
+ # Otherwise, ensure content is a string
+ else:
+ processed_msg["content"] = str(msg.get("content", ""))
+
+ message_history.append(processed_msg)
+
+ async with MCPClient(client_settings=client_settings_for_request) as mcp_client:
+ final_text, tool_trace, new_history = await mcp_client.invoke(message_history=message_history)
+
+ # Update the history for display.
+ # Keep the original message structure with attachments
+ for i in range(len(new_history) - 1, -1, -1):
+ if new_history[i].get("role") == "assistant":
+ # Preserve any attachments from the user message
+ user_message = state.messages[-1]
+ if "attachments" in user_message:
+ new_history[i]["attachments"] = user_message["attachments"]
+
+ new_history[i]["content"] = final_text
+ new_history[i]["tool_trace"] = tool_trace
+ break
+
+ state.messages = new_history
+ st.rerun()
+
+ except Exception as e:
+ logger.error("Exception during invoke call:", exc_info=True)
+ # Extract just the error message
+ error_msg = str(e)
+
+ # Check if it's a file-related error
+ if "file" in error_msg.lower() or "image" in error_msg.lower() or "content" in error_msg.lower():
+ st.error(f"Error: {error_msg}")
+
+ # Add a button to remove files and retry
+ if st.button("Remove files and retry", key="remove_files_retry"):
+ # Remove attachments from the latest message
+ if state.messages and "attachments" in state.messages[-1]:
+ del state.messages[-1]["attachments"]
+ st.rerun()
+ else:
+ st.error(f"Error: {error_msg}")
+
+ if st.button("Retry", key="reload_chatbot_error"):
+ if state.messages and state.messages[-1]["role"] == "user":
+ state.messages.pop()
+ st.rerun()
+
+
+if __name__ == "__main__" or (len(inspect.stack()) > 1 and "page" in inspect.stack()[1].filename):
+ try:
+ asyncio.run(main())
+ except ValueError as ex:
+ logger.exception("Bug detected: %s", ex)
+ st.error("It looks like you found a bug; please open an issue", icon="🛑")
+ st.stop()
+ except IndexError as ex:
+ logger.exception("Unable to contact the server: %s", ex)
+ st.error("Unable to contact the server, is it running?", icon="🚨")
+ if st.button("Retry", key="reload_chatbot"):
+ st_common.clear_state_key("user_client")
+ st.rerun()
diff --git a/src/client/content/config/config.py b/src/client/content/config/config.py
index 84a664ba..98a1a8f1 100644
--- a/src/client/content/config/config.py
+++ b/src/client/content/config/config.py
@@ -11,6 +11,7 @@
from client.content.config.tabs.oci import get_oci, display_oci
from client.content.config.tabs.databases import get_databases, display_databases
from client.content.config.tabs.models import get_models, display_models
+from client.content.config.tabs.mcp import get_mcp, display_mcp
def main() -> None:
@@ -20,6 +21,7 @@ def main() -> None:
get_databases()
get_models()
get_oci()
+ get_mcp()
tabs_list = []
if not state.disabled["settings"]:
@@ -30,6 +32,8 @@ def main() -> None:
tabs_list.append("🤖 Models")
if not state.disabled["oci_cfg"]:
tabs_list.append("☁️ OCI")
+ if not state.disabled["mcp_cfg"]:
+ tabs_list.append("🔗 MCP")
# Only create tabs if there is at least one
tab_index = 0
@@ -53,6 +57,10 @@ def main() -> None:
with tabs[tab_index]:
display_oci()
tab_index += 1
+ if not state.disabled["mcp_cfg"]:
+ with tabs[tab_index]:
+ display_mcp()
+ tab_index += 1
if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename:
diff --git a/src/client/content/config/tabs/databases.py b/src/client/content/config/tabs/databases.py
index 3bf4ee32..cab6b596 100644
--- a/src/client/content/config/tabs/databases.py
+++ b/src/client/content/config/tabs/databases.py
@@ -26,7 +26,6 @@ def get_databases(force: bool = False) -> None:
"""Get Databases from API Server"""
if force or "database_configs" not in state or not state.database_configs:
try:
- logger.info("Refreshing state.database_configs")
# Validation will be done on currently configured client database
# validation includes new vector_stores, etc.
client_database = state.client_settings.get("database", {}).get("alias", {})
diff --git a/src/client/content/config/tabs/mcp.py b/src/client/content/config/tabs/mcp.py
new file mode 100644
index 00000000..75621ba5
--- /dev/null
+++ b/src/client/content/config/tabs/mcp.py
@@ -0,0 +1,188 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+
+# spell-checker:ignore selectbox healthz
+import json
+
+import streamlit as st
+from streamlit import session_state as state
+
+from client.utils import api_call, st_common
+
+from common import logging_config
+
+logger = logging_config.logging.getLogger("client.content.config.tabs.mcp")
+
+
+###################################
+# Functions
+###################################
+def get_mcp_status() -> dict:
+ """Get MCP Status"""
+ try:
+ return api_call.get(endpoint="v1/mcp/healthz")
+ except api_call.ApiError as ex:
+ logger.error("Unable to get MCP Status: %s", ex)
+ return {}
+
+
+def get_mcp_client() -> str:
+ """Get MCP Client Configuration as pretty-printed JSON"""
+ try:
+ params = {"server": state.server["url"], "port": state.server["port"]}
+ mcp_client = api_call.get(endpoint="v1/mcp/client", params=params)
+ return json.dumps(mcp_client, indent=2)
+ except api_call.ApiError as ex:
+ logger.error("Unable to get MCP Client: %s", ex)
+ return "{}"
+
+
+def get_mcp(force: bool = False) -> None:
+ """Get MCP configs from API Server"""
+ if force or "mcp_configs" not in state or not state.mcp_configs:
+ logger.info("Refreshing state.mcp_configs")
+ endpoints = {
+ "tools": "v1/mcp/tools",
+ "prompts": "v1/mcp/prompts",
+ "resources": "v1/mcp/resources",
+ }
+ results = {}
+
+ for key, endpoint in endpoints.items():
+ try:
+ results[key] = api_call.get(endpoint=endpoint)
+ except api_call.ApiError as ex:
+ logger.error("Unable to get %s: %s", key, ex)
+ results[key] = {}
+
+ state.mcp_configs = results
+
+
+def extract_servers() -> list:
+ """Get a list of distinct MCP servers (by prefix)"""
+ prefixes = set()
+
+ for _, items in state.mcp_configs.items():
+ for item in items or []: # handle None safely
+ name = item.get("name")
+ if name and "_" in name:
+ prefix = name.split("_", 1)[0]
+ prefixes.add(prefix)
+
+ mcp_servers = sorted(prefixes)
+
+ if "optimizer" in mcp_servers:
+ mcp_servers.remove("optimizer")
+ mcp_servers.insert(0, "optimizer")
+
+ return mcp_servers
+
+
+@st.dialog(title="Details", width="large")
+def mcp_details(mcp_server: str, mcp_type: str, mcp_name: str) -> None:
+ """MCP Dialog Box"""
+ st.header(f"{mcp_name} - MCP server: {mcp_server}")
+ config = next((t for t in state.mcp_configs[mcp_type] if t.get("name") == f"{mcp_server}_{mcp_name}"), None)
+ if config is None:
+ st.warning(f"No details found for {mcp_name}.")
+ return
+ if config.get("description"):
+ st.code(config["description"], wrap_lines=True, height="content")
+ if config.get("inputSchema"):
+ st.subheader("inputSchema", divider="red")
+ properties = config["inputSchema"].get("properties", {})
+ required_fields = set(config["inputSchema"].get("required", []))
+ for name, prop in properties.items():
+ req = '(required)' if name in required_fields else ""
+ html = f"""
+ <b>{name}</b> {req}
+ <ul>
+ <li>Description: {prop.get("description", "")}</li>
+ <li>Type: {prop.get("type", "any")}</li>
+ <li>Default: {prop.get("default", "None")}</li>
+ </ul>
+ """
+ st.html(html)
+ if config.get("outputSchema"):
+ st.subheader("outputSchema", divider="red")
+ if config.get("arguments"):
+ st.subheader("arguments", divider="red")
+ if config.get("annotations"):
+ st.subheader("annotations", divider="red")
+ if config.get("meta"):
+ st.subheader("meta", divider="red")
+
+
+def render_configs(mcp_server: str, mcp_type: str, configs: list) -> None:
+ """Render rows of the MCP type"""
+ data_col_widths = [0.8, 0.2]
+ table_col_format = st.columns(data_col_widths, vertical_alignment="center")
+ col1, col2 = table_col_format
+ col1.markdown("Name", unsafe_allow_html=True)
+ col2.markdown("")
+ for mcp_name in configs:
+ col1.text_input(
+ "Name",
+ value=mcp_name,
+ label_visibility="collapsed",
+ disabled=True,
+ )
+ col2.button(
+ "Details",
+ on_click=mcp_details,
+ key=f"{mcp_server}_{mcp_name}_details",
+ kwargs=dict(mcp_server=mcp_server, mcp_type=mcp_type, mcp_name=mcp_name),
+ )
+
+
+#############################################################################
+# MAIN
+#############################################################################
+def display_mcp() -> None:
+ """Streamlit GUI"""
+ st.header("Model Context Protocol", divider="red")
+ try:
+ get_mcp()
+ except api_call.ApiError:
+ st.stop()
+ mcp_status = get_mcp_status()
+ if mcp_status.get("status") == "ready":
+ st.markdown(f"""
+ The {mcp_status["name"]} is running.
+ **Version**: {mcp_status["version"]}
+ """)
+ with st.expander("Client Configuration"):
+ st.code(get_mcp_client(), language="json")
+ else:
+ st.error("MCP Server is not running!", icon="🛑")
+ st.stop()
+
+ selected_mcp_server = st.selectbox(
+ "Configured MCP Server(s):",
+ options=extract_servers(),
+ # index=list(database_lookup.keys()).index(state.client_settings["database"]["alias"]),
+ key="selected_mcp_server",
+ # on_change=st_common.update_client_settings("database"),
+ )
+ if state.mcp_configs["tools"]:
+ tools_lookup = st_common.state_configs_lookup("mcp_configs", "name", "tools")
+ mcp_tools = [key.split("_", 1)[1] for key in tools_lookup if key.startswith(f"{selected_mcp_server}_")]
+ if mcp_tools:
+ st.subheader("Tools", divider="red")
+ render_configs(selected_mcp_server, "tools", mcp_tools)
+ if state.mcp_configs["prompts"]:
+ prompts_lookup = st_common.state_configs_lookup("mcp_configs", "name", "prompts")
+ mcp_prompts = [key.split("_", 1)[1] for key in prompts_lookup if key.startswith(f"{selected_mcp_server}_")]
+ if mcp_prompts:
+ st.subheader("Prompts", divider="red")
+ render_configs(selected_mcp_server, "prompts", mcp_prompts)
+ if state.mcp_configs["resources"]:
+ resources_lookup = st_common.state_configs_lookup("mcp_configs", "name", "resources")
+ mcp_resources = [key.split("_", 1)[1] for key in resources_lookup if key.startswith(f"{selected_mcp_server}_")]
+ if mcp_resources:
+ st.subheader("Resources", divider="red")
+ render_configs(selected_mcp_server, "resources", mcp_resources)
+
+
+if __name__ == "__main__":
+ display_mcp()
diff --git a/src/client/content/config/tabs/mcp_bak.py b/src/client/content/config/tabs/mcp_bak.py
new file mode 100644
index 00000000..89450432
--- /dev/null
+++ b/src/client/content/config/tabs/mcp_bak.py
@@ -0,0 +1,24 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+
+from client.mcp.frontend import display_commands_tab, display_ide_tab, get_fastapi_base_url, get_server_capabilities
+
+import streamlit as st
+
+def display_mcp():
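+ """Render the legacy MCP IDE/commands view (kept as a backup)."""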
+ fastapi_base_url = get_fastapi_base_url()
+ tools, resources, prompts = get_server_capabilities(fastapi_base_url)
+ if "chat_history" not in st.session_state:
+ st.session_state.chat_history = []
+
+ ide, commands = st.tabs(["🛠️ IDE", "📚 Available Commands"])
+
+ with ide:
+ # Display the IDE tab using the original AI Optimizer logic.
+ display_ide_tab()
+ with commands:
+ # Display the commands tab using the original AI Optimizer logic.
+ display_commands_tab(tools, resources, prompts)
diff --git a/src/client/mcp/client.py b/src/client/mcp/client.py
new file mode 100644
index 00000000..d4282828
--- /dev/null
+++ b/src/client/mcp/client.py
@@ -0,0 +1,446 @@
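+"""
+LangChain-backed client for Model Context Protocol (MCP) servers.
+
+Reads stdio server definitions from server_config.json, discovers each
+server's tools, resources, and prompts, and exposes an async invoke() loop
+in which the bound chat model may call MCP tools until it produces a final
+text answer.
+"""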
+import json
+import os
+import time
+import asyncio
+from dotenv import load_dotenv
+from mcp import ClientSession, StdioServerParameters, types
+from mcp.client.stdio import stdio_client
+from typing import List, Dict, Optional, Tuple, Type, Any
+from contextlib import AsyncExitStack
+
+# LangChain model and message types
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage
+from langchain_core.language_models.base import BaseLanguageModel
+from pydantic import create_model, BaseModel, Field
+# Import the specific chat models you want to support
+from langchain_openai import ChatOpenAI
+from langchain_anthropic import ChatAnthropic
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_cohere import ChatCohere
+from langchain_ollama import ChatOllama
+from langchain_groq import ChatGroq
+from langchain_mistralai import ChatMistralAI
+
+load_dotenv()
+
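+# Streamlit runs its own event loop; nest_asyncio permits re-entrant asyncio.run() calls inside it.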
+if os.getenv("IS_STREAMLIT_CONTEXT"):
+ import nest_asyncio
+ nest_asyncio.apply()
+
+class MCPClient:
+ """Connects to configured MCP servers and drives a LangChain chat model with their tools."""
+
+ def __init__(self, client_settings: Dict):
+ """
+ Initialize MCP Client using a settings dictionary from the Streamlit client.
+
+ Args:
+ client_settings: The state.client_settings object.
+ """
+ # 1. Validate the incoming settings dictionary
+ if not client_settings or 'll_model' not in client_settings:
+ raise ValueError("Client settings are incomplete. 'll_model' is required.")
+
+ # 2. Store the settings and extract the model ID
+ self.model_settings = client_settings['ll_model']
+
+ # Model factory using LangChain chat-model classes.
+ # If no model is specified, initialize with a default one.
+ if 'model' not in self.model_settings or not self.model_settings['model']:
+ # Set a default model if none is specified
+ self.model_settings['model'] = 'llama3.1'
+ # Remove any OpenAI-specific parameters that might cause issues
+ self.model_settings.pop('openai_api_key', None)
+
+ self.langchain_model = self._create_langchain_model(**self.model_settings)
+
+ self.exit_stack = AsyncExitStack()
+ self.sessions: Dict[str, ClientSession] = {}
+ self.tool_to_session: Dict[str, Tuple[ClientSession, types.Tool]] = {}
+ self.available_prompts: Dict[str, types.Prompt] = {}
+ self.static_resources: Dict[str, str] = {}
+ self.dynamic_resources: List[str] = []
+ self.resource_to_session: Dict[str, str] = {}
+ self.prompt_to_session: Dict[str, str] = {}
+ self.available_tools: List[Dict] = []
+ self._stdio_generators: Dict[str, Any] = {} # To store stdio generators for cleanup
+ print(f"Initialized MCPClient with LangChain model: {self.langchain_model.__class__.__name__}")
+
+ # Async context-manager support: connect on enter, clean up on exit.
+ async def __aenter__(self):
+ """Enter the async context, connecting to all servers."""
+ await self.connect_to_servers()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ """Exit the async context, ensuring all connections are cleaned up."""
+ await self.cleanup()
+
+ def _create_langchain_model(self, model: str, **kwargs) -> BaseLanguageModel:
+ """Factory to create and return a LangChain ChatModel instance."""
+ # If no model is specified, default to llama3.1 which works with Ollama
+ if not model:
+ model = "llama3.1"
+ # Remove any OpenAI-specific parameters that might cause issues
+ kwargs.pop('openai_api_key', None)
+
+ model_lower = model.lower()
+
+ # Handle OpenAI models
+ if model_lower.startswith('gpt-'):
+ # Check if api_key is in kwargs and rename it to openai_api_key for ChatOpenAI
+ if 'api_key' in kwargs:
+ kwargs['openai_api_key'] = kwargs.pop('api_key')
+ # Remove parameters that shouldn't be passed to ChatOpenAI
+ kwargs.pop('context_length', None)
+ kwargs.pop('chat_history', None)
+ return ChatOpenAI(model=model, **kwargs)
+
+ # Handle Anthropic models
+ elif model_lower.startswith('claude-'):
+ kwargs.pop('openai_api_key', None)
+ return ChatAnthropic(model=model, **kwargs)
+
+ # Handle Google models
+ elif model_lower.startswith('gemini-'):
+ kwargs.pop('openai_api_key', None)
+ return ChatGoogleGenerativeAI(model=model, **kwargs)
+
+ # Handle Mistral models
+ elif model_lower.startswith('mistral-'):
+ kwargs.pop('openai_api_key', None)
+ return ChatMistralAI(model=model, **kwargs)
+
+ # Handle Cohere models
+ elif model_lower.startswith('cohere-'):
+ kwargs.pop('openai_api_key', None)
+ return ChatCohere(model=model, **kwargs)
+
+ # Handle Groq models
+ elif model_lower.startswith('groq-'):
+ kwargs.pop('openai_api_key', None)
+ return ChatGroq(model=model, **kwargs)
+
+ # Default to Ollama for any other model name
+ else:
+ return ChatOllama(model=model, **kwargs)
+
+ def _convert_dict_to_langchain_messages(self, message_history: List[Dict]) -> List[BaseMessage]:
+ """Converts a list of message dictionaries to a list of LangChain message objects."""
+ messages: List[BaseMessage] = []
+ for msg in message_history:
+ role = msg.get("role")
+ content = msg.get("content", "")
+ if role == "user":
+ messages.append(HumanMessage(content=content)) # type: ignore
+ elif role == "assistant":
+ # AIMessage can handle tool calls directly from the dictionary format
+ tool_calls = msg.get("tool_calls")
+ messages.append(AIMessage(content=content, tool_calls=tool_calls or [])) # type: ignore
+ elif role == "system":
+ messages.append(SystemMessage(content=content)) # type: ignore
+ elif role == "tool":
+ messages.append(ToolMessage(content=content, tool_call_id=msg.get("tool_call_id", ""))) # type: ignore
+ return messages # type: ignore
+
+ def _convert_langchain_messages_to_dict(self, langchain_messages: List[BaseMessage]) -> List[Dict]:
+ """Converts a list of LangChain message objects back to a list of dictionaries for session state."""
+ dict_messages = []
+ for msg in langchain_messages:
+ if isinstance(msg, HumanMessage):
+ dict_messages.append({"role": "user", "content": msg.content})
+ elif isinstance(msg, AIMessage):
+ # Preserve tool calls in the dictionary format
+ dict_messages.append({"role": "assistant", "content": msg.content, "tool_calls": msg.tool_calls})
+ elif isinstance(msg, SystemMessage):
+ dict_messages.append({"role": "system", "content": msg.content})
+ elif isinstance(msg, ToolMessage):
+ dict_messages.append({"role": "tool", "content": msg.content, "tool_call_id": msg.tool_call_id})
+ return dict_messages
+
+ def _prepare_messages_for_service(self, message_history: List[Dict]) -> List[Dict]:
+ """
+ Translates the rich message history from the GUI into a simple,
+ text-only format that AI services can understand.
+ """
+ prepared_messages = []
+ for msg in message_history:
+ content = msg.get("content")
+ # If content is a list (multimodal), extract only the text.
+ if isinstance(content, list):
+ text_content = " ".join(
+ part["text"] for part in content if part.get("type") == "text"
+ )
+ prepared_messages.append({"role": msg["role"], "content": text_content})
+ # Otherwise, use the content as is (assuming it's a string).
+ else:
+ prepared_messages.append(msg)
+ return prepared_messages
+
+ async def connect_to_servers(self):
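+ """Load server_config.json (trying both relative locations) and connect to every configured MCP server."""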
+ try:
+ config_paths = ["server/mcp/server_config.json", os.path.join(os.path.dirname(__file__), "..", "..", "server", "mcp", "server_config.json")]
+ servers = {}
+ for config_path in config_paths:
+ try:
+ with open(config_path, "r") as file:
+ servers = json.load(file).get("mcpServers", {})
+ print(f"Loaded MCP server configuration from: {config_path}")
+ print(f"Found servers: {list(servers.keys())}")
+ break
+ except FileNotFoundError:
+ print(f"MCP server config not found at: {config_path}")
+ continue
+ except Exception as e:
+ print(f"Error reading MCP server config from {config_path}: {e}")
+ continue
+ if not servers:
+ print("No MCP server configuration found!")
+ for name, config in servers.items():
+ print(f"Connecting to MCP server: {name}")
+ await self.connect_to_server(name, config)
+ except Exception as e:
+ print(f"Error loading server configuration: {e}")
+
+ async def connect_to_server(self, server_name: str, server_config: dict):
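+ """Open a stdio transport to a single MCP server and load its capabilities."""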
+ try:
+ print(f"Connecting to server '{server_name}' with config: {server_config}")
+ server_params = StdioServerParameters(**server_config)
+
+ # Create the stdio client connection using the exit stack for proper cleanup
+ try:
+ read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
+
+ # Create the client session using the exit stack for proper cleanup
+ session = await self.exit_stack.enter_async_context(ClientSession(read, write))
+
+ await session.initialize()
+ self.sessions[server_name] = session
+
+ # Load tools, resources, and prompts from this server
+ await self._load_server_capabilities(session, server_name)
+ except RuntimeError as e:
+ # Handle runtime errors related to task context
+ if "cancel scope" not in str(e).lower():
+ raise
+ print(f"Warning: Connection to '{server_name}' had context issues: {e}")
+ except Exception as e:
+ print(f"Failed to connect to '{server_name}': {e}")
+ import traceback
+ traceback.print_exc()
+
+ async def _run_async_generator(self, generator):
+ """Helper method to run an async generator in the current task context."""
+ return await generator.__anext__()
+
+ async def _load_server_capabilities(self, session: ClientSession, server_name: str):
+ """Load tools, resources, and prompts from a connected server."""
+ try:
+ # List tools
+ tools_list = await session.list_tools()
+ print(f"Found {len(tools_list.tools)} tools from server '{server_name}'")
+ for tool in tools_list.tools:
+ self.tool_to_session[tool.name] = (session, tool)
+ print(f"Loaded tool '{tool.name}' from server '{server_name}'")
+
+ # List resources
+ try:
+ resp = await session.list_resources()
+ if resp.resources:
+ print(f" - Found Static Resources: {[r.name for r in resp.resources]}")
+ for resource in resp.resources:
+ uri = resource.uri.encoded_string()
+ self.resource_to_session[uri] = server_name
+ user_shortcut = uri.split('//')[-1]
+ self.static_resources[user_shortcut] = uri
+ if resource.name and resource.name != user_shortcut:
+ self.static_resources[resource.name] = uri
+ except Exception as e:
+ print(f"Failed to load resources from server '{server_name}': {e}")
+
+ # Discover DYNAMIC resource templates
+ try:
+ # The response object for templates has a `.templates` attribute
+ resp = await session.list_resource_templates()
+ if resp.resourceTemplates:
+ print(f" - Found Dynamic Resource Templates: {[t.name for t in resp.resourceTemplates]}")
+ for template in resp.resourceTemplates:
+ uri = template.uriTemplate
+ # The key for the session map MUST be the pattern itself.
+ self.resource_to_session[uri] = server_name
+ if uri not in self.dynamic_resources:
+ self.dynamic_resources.append(uri)
+ except Exception as e:
+ # This is also okay, some servers don't have dynamic resources.
+ print(f"Failed to load dynamic resources from server '{server_name}': {e}")
+
+ # List prompts
+ try:
+ prompts_list = await session.list_prompts()
+ print(f"Found {len(prompts_list.prompts)} prompts from server '{server_name}'")
+ for prompt in prompts_list.prompts:
+ self.available_prompts[prompt.name] = prompt
+ self.prompt_to_session[prompt.name] = server_name
+ print(f"Loaded prompt '{prompt.name}' from server '{server_name}'")
+ except Exception as e:
+ print(f"Failed to load prompts from server '{server_name}': {e}")
+
+ except Exception as e:
+ print(f"Failed to load capabilities from server '{server_name}': {e}")
+
+ async def _rebuild_mcp_tool_schemas(self):
+ """Rebuilds the list of tools from connected MCP servers in a LangChain-compatible format."""
+ self.available_tools = []
+ for _, (_, tool_object) in self.tool_to_session.items():
+ # LangChain's .bind_tools can often work directly with this MCP schema
+ tool_schema = {
+ "name": tool_object.name,
+ "description": tool_object.description,
+ "args_schema": self.create_pydantic_model_from_schema(tool_object.name, tool_object.inputSchema)
+ }
+ self.available_tools.append(tool_schema)
+ print(f"Available tools after rebuild: {len(self.available_tools)}")
+
+ def create_pydantic_model_from_schema(self, name: str, schema: dict) -> Type[BaseModel]:
+ """Dynamically creates a Pydantic model from a JSON schema for LangChain tool binding."""
+ fields = {}
+ if schema and 'properties' in schema:
+ for prop_name, prop_details in schema['properties'].items():
+ field_type = str # Default to string
+ # A more robust implementation would map JSON schema types to Python types
+ if prop_details.get('type') == 'integer':
+ field_type = int
+ elif prop_details.get('type') == 'number':
+ field_type = float
+ elif prop_details.get('type') == 'boolean':
+ field_type = bool
+
+ fields[prop_name] = (field_type, Field(..., description=prop_details.get('description')))
+
+ return create_model(name, **fields) # type: ignore
+
+ async def execute_mcp_tool(self, tool_name: str, tool_args: Dict) -> str:
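+ """Call the named tool on its owning session and flatten the result content to a string."""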
+ try:
+ session, _ = self.tool_to_session[tool_name]
+ result = await session.call_tool(tool_name, arguments=tool_args)
+ if not result.content:
+ return "Tool executed successfully."
+
+ # Handle different content types properly
+ if isinstance(result.content, list):
+ text_parts = []
+ for item in result.content:
+ # Check if item has a text attribute
+ if hasattr(item, 'text'):
+ text_parts.append(str(item.text))
+ else:
+ # Handle other content types
+ text_parts.append(str(item))
+ return " | ".join(text_parts)
+ else:
+ return str(result.content)
+ except Exception as e:
+ # Check if it's a closed resource error
+ if "ClosedResourceError" in str(type(e)) or "closed" in str(e).lower():
+ raise Exception("MCP session is closed. Please try again.") from e
+ else:
+ raise
+
+ async def invoke(self, message_history: List[Dict]) -> Tuple[str, List[Dict], List[Dict]]:
+ """
+ Main entry point. Returns a tuple of:
+ (final_text_response, tool_calls_trace, new_full_history)
+ """
+ max_retries = 3
+ for attempt in range(max_retries):
+ try:
+ langchain_messages = self._convert_dict_to_langchain_messages(message_history)
+
+ # Separate the final text response from the tool trace
+ final_text_response = ""
+ tool_calls_trace = []
+
+ max_iterations = 10
+ tool_execution_failed = False
+ for iteration in range(max_iterations):
+ await self._rebuild_mcp_tool_schemas()
+ model_with_tools = self.langchain_model.bind_tools(self.available_tools)
+ response_message: AIMessage = await model_with_tools.ainvoke(langchain_messages)
+ langchain_messages.append(response_message)
+
+ # Capture the final text response from the last message
+ if response_message.content:
+ final_text_response = response_message.content
+
+ if not response_message.tool_calls:
+ break
+
+ for tool_call in response_message.tool_calls:
+ tool_name = tool_call['name']
+ tool_args = tool_call['args']
+
+ try:
+ result_content = await self.execute_mcp_tool(tool_name, tool_args)
+ tool_calls_trace.append({
+ "name": tool_name,
+ "args": tool_args,
+ "result": result_content
+ })
+ except Exception as e:
+ if "MCP session is closed" in str(e) and attempt < max_retries - 1:
+ print(f"MCP session closed, reinitializing (attempt {attempt + 1})")
+ await self.cleanup()
+ await self.connect_to_servers()
+ await asyncio.sleep(0.1)
+ tool_execution_failed = True
+ break
+ else:
+ result_content = f"Error executing tool {tool_name}: {e}"
+ tool_calls_trace.append({
+ "name": tool_name,
+ "args": tool_args,
+ "error": result_content
+ })
+
+ langchain_messages.append(ToolMessage(content=result_content, tool_call_id=tool_call['id']))
+
+ if tool_execution_failed:
+ break
+
+ if tool_execution_failed and attempt < max_retries - 1:
+ continue
+
+ final_history_dict = self._convert_langchain_messages_to_dict(langchain_messages)
+
+ return final_text_response, tool_calls_trace, final_history_dict
+
+ except RuntimeError as e:
+ if "Event loop is closed" in str(e) and attempt < max_retries - 1:
+ print(f"Event loop closed, reinitializing model (attempt {attempt + 1})")
+ self.langchain_model = self._create_langchain_model(**self.model_settings)
+ await asyncio.sleep(0.1)
+ continue
+ else:
+ raise Exception("Event loop closed. Please try again.") from e
+ except Exception as e:
+ if attempt >= max_retries - 1:
+ raise
+ print(f"Invoke attempt {attempt + 1} failed, retrying: {e}")
+ await asyncio.sleep(0.1)
+
+ raise Exception("Failed to invoke MCP client after all retries")
+
+ async def cleanup(self):
+ """Clean up all resources properly."""
+ try:
+ # Close all sessions using the exit stack to avoid context issues
+ await self.exit_stack.aclose()
+ except Exception as e:
+ # Suppress errors related to async context management as they don't affect functionality
+ if "cancel scope" not in str(e).lower() and "asyncio" not in str(e).lower():
+ print(f"Error during cleanup: {e}")
+
+ try:
+ # Clear sessions
+ self.sessions.clear()
+
+ # Clear other data structures
+ self.tool_to_session.clear()
+ self.available_prompts.clear()
+ self.static_resources.clear()
+ self.dynamic_resources.clear()
+ self.resource_to_session.clear()
+ self.prompt_to_session.clear()
+ self.available_tools.clear()
+
+ # Recreate the exit stack for future use
+ self.exit_stack = AsyncExitStack()
+ except Exception as e:
+ print(f"Error during cleanup: {e}")
diff --git a/src/client/mcp/frontend.py b/src/client/mcp/frontend.py
new file mode 100644
index 00000000..2c645129
--- /dev/null
+++ b/src/client/mcp/frontend.py
@@ -0,0 +1,60 @@
+import streamlit as st
+import os
+import requests
+import json
+
+def set_page():
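+ """Configure the Streamlit page for the standalone MCP chatbot."""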
+ st.set_page_config(
+ page_title="MCP Universal Chatbot",
+ page_icon="🤖",
+ layout="wide"
+ )
+
+def get_fastapi_base_url():
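+ """Return the FastAPI base URL, defaulting to the local development server."""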
+ return os.getenv("FASTAPI_BASE_URL", "http://127.0.0.1:8000")
+
+
+def get_server_files():
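+ """Collect server_config.json plus any referenced server scripts that exist on disk."""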
+ files = ["server/mcp/server_config.json"]
+ try:
+ with open("server/mcp/server_config.json", "r") as f:
+ config = json.load(f)
+ for server in config.get("mcpServers", {}).values():
+ script_path = server.get("args", [None])[0]
+ if script_path and os.path.exists(script_path):
+ files.append(script_path)
+ except FileNotFoundError:
+ st.sidebar.error("server_config.json not found!")
+ return list(set(files))
+
+def display_ide_tab():
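+ """Render an in-app editor for the MCP server config and scripts."""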
+ st.header("🔧 Integrated MCP Server IDE")
+ st.info("Edit your server configuration or scripts. Restart the launcher for changes to take effect.")
+ server_files = get_server_files()
+ selected_file = st.selectbox("Select a file to edit", options=server_files)
+ if selected_file:
+ with open(selected_file, "r") as f:
+ file_content = f.read()
+ from streamlit_ace import st_ace
+ language = "python" if selected_file.endswith(".py") else "json"
+ new_content = st_ace(value=file_content, language=language, theme="monokai", keybinding="vscode", height=500, auto_update=True)
+ if st.button("Save Changes"):
+ with open(selected_file, "w") as f:
+ f.write(new_content)
+ st.success(f"Successfully saved {selected_file}!")
+
+def display_commands_tab(tools, resources, prompts):
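+ """Show tools, resources, and prompts discovered from the MCP backend."""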
+ st.header("📖 Discovered MCP Commands")
+ st.info("These commands were discovered from the MCP backend.")
+
+ if tools:
+ with st.expander("🛠️ Available Tools (Used automatically by the AI)", expanded=True):
+ # Extract just the tool names from the tools response
+ if "tools" in tools and isinstance(tools["tools"], list):
+ tool_names = [tool.get("name", tool) if isinstance(tool, dict) else tool for tool in tools["tools"]]
+ st.write(tool_names)
+ else:
+ st.json(tools)
+
+ if resources:
+ with st.expander("📦 Available Resources (Use with `@` or just ``)"):
+ st.json(resources)
+
+ if prompts:
+ with st.expander("📝 Available Prompts (Use with `/prompt ` or select in chat)"):
+ st.json(prompts)
diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py
index 675248ef..218e313d 100644
--- a/src/client/utils/st_common.py
+++ b/src/client/utils/st_common.py
@@ -15,6 +15,12 @@
from common import logging_config, help_text
from common.schema import PromptPromptType, PromptNameType, SelectAISettings
+# Import the MCP initialization function
+try:
+ from launch_server import initialize_mcp_engine_with_model
+except ImportError:
+ initialize_mcp_engine_with_model = None
+
logger = logging_config.logging.getLogger("client.utils.st_common")
@@ -27,9 +33,11 @@ def clear_state_key(state_key: str) -> None:
logger.debug("State cleared: %s", state_key)
-def state_configs_lookup(state_configs_name: str, key: str) -> dict[str, dict[str, Any]]:
+def state_configs_lookup(state_configs_name: str, key: str, section: str = None) -> dict[str, dict[str, Any]]:
"""Convert state. into a lookup based on key"""
configs = getattr(state, state_configs_name)
+ if section:
+ configs = configs.get(section, [])
return {config[key]: config for config in configs if key in config}
@@ -164,6 +172,8 @@ def ll_sidebar() -> None:
selected_model = state.client_settings["ll_model"]["model"]
ll_idx = list(ll_models_enabled.keys()).index(selected_model)
if not state.client_settings["selectai"]["enabled"]:
+ # Store the previous model to detect changes
+ previous_model = selected_model
selected_model = st.sidebar.selectbox(
"Chat model:",
options=list(ll_models_enabled.keys()),
@@ -173,6 +183,18 @@ def ll_sidebar() -> None:
disabled=state.client_settings["selectai"]["enabled"],
)
+ # If the model has changed, reinitialize the MCP engine
+ if selected_model != previous_model and initialize_mcp_engine_with_model:
+ try:
+ # Instead of creating a new event loop, we'll set a flag to indicate
+ # that the MCP engine needs to be reinitialized
+ state.mcp_needs_reinit = selected_model
+ logger.info("MCP engine marked for reinitialization with model: %s", selected_model)
+ except Exception as ex:
+ logger.error(
+ "Failed to mark MCP engine for reinitialization with model %s: %s", selected_model, str(ex)
+ )
+
# Temperature
temperature = ll_models_enabled[selected_model]["temperature"]
user_temperature = state.client_settings["ll_model"]["temperature"]
diff --git a/src/client/utils/st_footer.py b/src/client/utils/st_footer.py
index b8d5b643..8314171e 100644
--- a/src/client/utils/st_footer.py
+++ b/src/client/utils/st_footer.py
@@ -65,25 +65,7 @@ def _inject_footer(selector, insertion_method, footer_html, cleanup_styles=True)
"""
components.html(js_code, height=0)
-
-# --- FUNCTION 1: The Cleanup Crew ---
-def remove_footer():
- """
- Injects simple JavaScript to find and remove any existing footer.
- This MUST be called at the TOP of every page in your app.
- """
- js_code = """
-
- """
- components.html(js_code, height=0)
-
-
-# --- FUNCTION 2: The Chat Page Footer ---
+# --- The Chat Page Footer ---
def render_chat_footer():
"""
Standardized footer for chat pages.
@@ -97,22 +79,3 @@ def render_chat_footer():
_inject_footer(
selector='[data-testid="stBottomBlockContainer"]', insertion_method="afterend", footer_html=footer_html
)
-
-
-# --- FUNCTION 3: The Models Page Footer ---
-def render_models_footer():
- """
- Standardized footer for models pages.
- """
- footer_html = f"""
- {FOOTER_STYLE}
-
- """
- _inject_footer(
- selector='[data-testid="stAppIframeResizerAnchor"]',
- insertion_method="beforebegin",
- footer_html=footer_html,
- cleanup_styles=False,
- )
diff --git a/src/common/schema.py b/src/common/schema.py
index 5cc42fcf..164dc258 100644
--- a/src/common/schema.py
+++ b/src/common/schema.py
@@ -108,7 +108,9 @@ class DatabaseVectorStorage(BaseModel):
"""Database Vector Storage Tables"""
vector_store: Optional[str] = Field(
- default=None, description="Vector Store Table Name (auto-generated, do not set)", readOnly=True
+ default=None,
+ description="Vector Store Table Name (auto-generated, do not set)",
+ json_schema_extra={"readOnly": True},
)
alias: Optional[str] = Field(default=None, description="Identifiable Alias")
model: Optional[str] = Field(default=None, description="Embedding Model")
@@ -121,8 +123,8 @@ class DatabaseVectorStorage(BaseModel):
class DatabaseSelectAIObjects(BaseModel):
"""Database SelectAI Objects"""
- owner: Optional[str] = Field(default=None, description="Object Owner", readOnly=True)
- name: Optional[str] = Field(default=None, description="Object Name", readOnly=True)
+ owner: Optional[str] = Field(default=None, description="Object Owner", json_schema_extra={"readOnly": True})
+ name: Optional[str] = Field(default=None, description="Object Name", json_schema_extra={"readOnly": True})
enabled: bool = Field(default=False, description="SelectAI Enabled")
@@ -144,12 +146,14 @@ class Database(DatabaseAuth):
"""Database Object"""
name: str = Field(default="DEFAULT", description="Name of Database (Alias)")
- connected: bool = Field(default=False, description="Connection Established", readOnly=True)
+ connected: bool = Field(default=False, description="Connection Established", json_schema_extra={"readOnly": True})
vector_stores: Optional[list[DatabaseVectorStorage]] = Field(
- default=[], description="Vector Storage (read-only)", readOnly=True
+ default=[], description="Vector Storage (read-only)", json_schema_extra={"readOnly": True}
)
selectai: bool = Field(default=False, description="SelectAI Possible")
- selectai_profiles: Optional[list] = Field(default=[], description="SelectAI Profiles (read-only)", readOnly=True)
+ selectai_profiles: Optional[list] = Field(
+ default=[], description="SelectAI Profiles (read-only)", json_schema_extra={"readOnly": True}
+ )
# Do not expose the connection to the endpoint
_connection: oracledb.Connection = PrivateAttr(default=None)
@@ -163,6 +167,40 @@ def set_connection(self, connection: oracledb.Connection) -> None:
self._connection = connection
+#####################################################
+# MCP
+#####################################################
+class MCPModelConfig(BaseModel):
+ """MCP Model Configuration"""
+
+ model_id: str = Field(..., description="Model identifier")
+ service_type: Literal["ollama", "openai"] = Field(..., description="AI service type")
+ base_url: str = Field(default="http://localhost:11434", description="Base URL for API")
+ api_key: Optional[str] = Field(default=None, description="API key", json_schema_extra={"sensitive": True})
+ enabled: bool = Field(default=True, description="Model availability status")
+ streaming: bool = Field(default=False, description="Enable streaming responses")
+ temperature: float = Field(default=1.0, description="Model temperature")
+ max_tokens: int = Field(default=2048, description="Maximum tokens per response")
+
+
+class MCPToolConfig(BaseModel):
+ """MCP Tool Configuration"""
+
+ name: str = Field(..., description="Tool name")
+ description: str = Field(..., description="Tool description")
+ parameters: dict[str, Any] = Field(..., description="Tool parameters")
+ enabled: bool = Field(default=True, description="Tool availability status")
+
+
+class MCPSettings(BaseModel):
+ """MCP Global Settings"""
+
+ models: list[MCPModelConfig] = Field(default_factory=list, description="Available MCP models")
+ tools: list[MCPToolConfig] = Field(default_factory=list, description="Available MCP tools")
+ default_model: Optional[str] = Field(default=None, description="Default model identifier")
+ enabled: bool = Field(default=True, description="Enable or disable MCP functionality")
+
+
#####################################################
# Models
#####################################################
@@ -225,6 +263,16 @@ def check_provider(self):
raise ValueError(f"Provider '{self.provider}' is not valid. Must be one of: {providers}")
return self
+ def check_provider_matches_type(self):
+ """Validate the provider, allowing it to remain unset"""
+ providers = get_args(ModelProviders)
+ if not self.provider or self.provider == "unset":
+ return self
+
+ if self.provider not in providers:
+ raise ValueError(f"Provider '{self.provider}' is not valid. Must be one of: {providers}")
+ return self
+
#####################################################
# Oracle Cloud Infrastructure
@@ -239,7 +287,9 @@ class OracleCloudSettings(BaseModel):
"""Store Oracle Cloud Infrastructure Settings"""
auth_profile: str = Field(default="DEFAULT", description="Config File Profile")
- namespace: Optional[str] = Field(default=None, description="Object Store Namespace", readOnly=True)
+ namespace: Optional[str] = Field(
+ default=None, description="Object Store Namespace", json_schema_extra={"readOnly": True}
+ )
user: Optional[str] = Field(
default=None,
description="Optional if using Auth Token",
@@ -380,6 +430,7 @@ class Configuration(BaseModel):
model_configs: Optional[list[Model]] = None
oci_configs: Optional[list[OracleCloudSettings]] = None
prompt_configs: Optional[list[Prompt]] = None
+ mcp_configs: Optional[list[MCPModelConfig]] = Field(default=None, description="List of MCP configurations")
def model_dump_public(self, incl_sensitive: bool = False, incl_readonly: bool = False) -> dict:
"""Remove marked fields for FastAPI Response"""
@@ -502,3 +553,6 @@ class EvaluationReport(Evaluation):
TestSetsIdType = TestSets.__annotations__["tid"]
TestSetsNameType = TestSets.__annotations__["name"]
TestSetDateType = TestSets.__annotations__["created"]
+MCPModelIdType = MCPModelConfig.__annotations__["model_id"]
+MCPServiceType = MCPModelConfig.__annotations__["service_type"]
+MCPToolNameType = MCPToolConfig.__annotations__["name"]
diff --git a/src/hello_world.py b/src/hello_world.py
new file mode 100644
index 00000000..2028682b
--- /dev/null
+++ b/src/hello_world.py
@@ -0,0 +1,81 @@
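+"""
+Smoke test: connect a LangGraph ReAct agent to the optimizer MCP server over
+streamable HTTP and run a sample tool-calling question. The commented-out
+blocks below are alternative connection approaches kept for reference.
+"""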
+import asyncio
+from mcp import ClientSession
+from mcp.client.streamable_http import streamablehttp_client
+
+from langgraph.prebuilt import create_react_agent
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from langchain_mcp_adapters.tools import load_mcp_tools
+
+
+client = MultiServerMCPClient(
+ {
+ "optimizer": {
+ "transport": "streamable_http",
+ "url": "http://localhost:8000/mcp/",
+ "headers": {"Authorization": "Bearer demo_api_key"},
+ }
+ }
+)
+async def call_tool(name: str):
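+ """Discover MCP tools and run a sample question through a ReAct agent (name is currently unused)."""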
+ tools = await client.get_tools()
+ agent = create_react_agent("openai:gpt-4o-mini", tools)
+ math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"})
+ print(math_response)
+
+# async def call_tool(name: str):
+# async with client.session("optimizer") as session:
+ # tools = await load_mcp_tools(session)
+ # agent = create_react_agent("openai:gpt-4o-mini", tools)
+ # # math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"})
+ # # weather_response = await agent.ainvoke({"messages": "what is the weather in nyc?"})
+ # database_response = await agent.ainvoke({"messages": "connect to OPTIMIZER_DEFAULT"})
+ # database_response = await agent.ainvoke({"messages": "show me a list of table names"})
+ # print(database_response)
+ # # print(weather_response)
+
+asyncio.run(call_tool("Ford"))
+
+# async def call_tool(name: str):
+# async with streamablehttp_client(config) as (read, write, _):
+# async with ClientSession(read, write) as session:
+# # Initialize the connection
+# await session.initialize()
+
+# # Get tools
+# tools = await load_mcp_tools(session)
+# agent = create_react_agent("openai:gpt-4o-mini", tools)
+# math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"})
+# print(math_response)
+
+
+# asyncio.run(call_tool("Ford"))
+
+# client = Client(config)
+# async def call_tool(name: str):
+# async with client:
+# print(f"Connected: {client.is_connected()}")
+# tools = await client.load_mcp_tools(client)
+# # agent = create_react_agent("openai:gpt-4o-mini", tools)
+# # math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"})
+# # print(math_response)
+# # result = await client.call_tool("optimizer_greet", {"name": name})
+# # print(result)
+# # result = await client.call_tool("optimizer_multiply", {"a": 5, "b": 3})
+# # print(result)
+
+
+# from mcp import ClientSession
+# from mcp.client.streamable_http import streamablehttp_client
+
+# from langgraph.prebuilt import create_react_agent
+# from langchain_mcp_adapters.tools import load_mcp_tools
+
+# async with streamablehttp_client("http://localhost:8000/mcp/") as (read, write, _):
+# async with ClientSession(read, write) as session:
+# # Initialize the connection
+# await session.initialize()
+
+# # Get tools
+# tools = await load_mcp_tools(session)
+# agent = create_react_agent("openai:gpt-4.1", tools)
+# math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"})
diff --git a/src/launch_client.py b/src/launch_client.py
index d4051512..60af0451 100644
--- a/src/launch_client.py
+++ b/src/launch_client.py
@@ -91,7 +91,7 @@ def main() -> None:
}
.stAppHeader img[alt="Logo"] {
width: 50%;
- }
+ }
""",
)
@@ -131,6 +131,7 @@ def main() -> None:
state.disabled["model_cfg"] = os.environ.get("DISABLE_MODEL_CFG", "false").lower() == "true"
state.disabled["oci_cfg"] = os.environ.get("DISABLE_OCI_CFG", "false").lower() == "true"
state.disabled["settings"] = os.environ.get("DISABLE_SETTINGS", "false").lower() == "true"
+ state.disabled["mcp_cfg"] = os.environ.get("DISABLE_MCP_CFG", "false").lower() == "true"
# Left Hand Side - Navigation
chatbot = st.Page("client/content/chatbot.py", title="ChatBot", icon="💬", default=True)
diff --git a/src/launch_server.py b/src/launch_server.py
index fde83bda..a14d8ca4 100644
--- a/src/launch_server.py
+++ b/src/launch_server.py
@@ -2,13 +2,17 @@
Copyright (c) 2024, 2025, Oracle and/or its affiliates.
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
"""
-# spell-checker:ignore fastapi laddr checkpointer langgraph litellm
-# spell-checker:ignore noauth apiserver configfile selectai giskard ollama llms
-# pylint: disable=redefined-outer-name,wrong-import-position
+# spell-checker:ignore configfile fastmcp noauth selectai getpid procs litellm giskard ollama
+# spell-checker:ignore dotenv apiserver laddr
-import os
+# Patch litellm for Giskard/Ollama issue
+import server.patches.litellm_patch # pylint: disable=unused-import, wrong-import-order
# Set OS Environment (Don't move their position to reflect on imports)
+# pylint: disable=wrong-import-position
+import os
+
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
os.environ["GSK_DISABLE_SENTRY"] = "true"
os.environ["GSK_DISABLE_ANALYTICS"] = "true"
@@ -16,37 +20,49 @@
app_home = os.path.dirname(os.path.abspath(__file__))
if "TNS_ADMIN" not in os.environ:
os.environ["TNS_ADMIN"] = os.path.join(app_home, "tns_admin")
-
-# Patch litellm for Giskard/Ollama issue
-import server.patches.litellm_patch # pylint: disable=unused-import
+# pylint: enable=wrong-import-position
import argparse
-import queue
+import asyncio
+from contextlib import asynccontextmanager
import secrets
import socket
import subprocess
-import threading
+import sys
from typing import Annotated
from pathlib import Path
-import uvicorn
-import psutil
-from fastapi import FastAPI, HTTPException, Depends, status, APIRouter
+
+# Third Party
+import psutil
+import uvicorn
+from fastapi import FastAPI, APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from fastmcp import FastMCP, settings
+from fastmcp.server.auth import StaticTokenVerifier
+from langgraph.checkpoint.memory import InMemorySaver
+
+# Configuration
+from server.bootstrap import configfile # pylint: disable=ungrouped-imports
# Logging
from common import logging_config
from common._version import __version__
-# Configuration
-from server.bootstrap import configfile
-
logger = logging_config.logging.getLogger("launch_server")
+# Establish LangGraph Short-Term Memory (thread-level persistence)
+graph_memory = InMemorySaver()
##########################################
-# Process Control
+# Client Process Control
##########################################
def start_server(port: int = 8000, logfile: bool = False) -> int:
"""Start the uvicorn server for FastAPI"""
@@ -71,54 +87,33 @@ def get_pid_using_port(port: int) -> int:
continue
return None
- def start_subprocess(port: int, logfile: bool) -> subprocess.Popen:
- """Start the uvicorn server as a subprocess."""
- logger.info("API server starting on port: %i", port)
- log_file = open(f"apiserver_{port}.log", "a", encoding="utf-8") if logfile else None
- stdout = stderr = log_file if logfile else subprocess.PIPE
- process = subprocess.Popen(
- [
- "uvicorn",
- "launch_server:create_app",
- "--factory",
- "--host",
- "0.0.0.0",
- "--port",
- str(port),
- ],
- stdout=stdout,
- stderr=stderr,
- )
- logger.info("API server started on Port: %i; PID: %i", port, process.pid)
- return process
-
port = port or find_available_port()
- existing_pid = get_pid_using_port(port)
- if existing_pid:
+ if existing_pid := get_pid_using_port(port):
logger.info("API server already running on port: %i (PID: %i)", port, existing_pid)
return existing_pid
- popen_queue = queue.Queue()
- thread = threading.Thread(
- target=lambda: popen_queue.put(start_subprocess(port, logfile)),
- daemon=True,
- )
- thread.start()
+ client_args = [sys.executable, __file__, "--port", str(port)]
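+ # Re-executes this module; the child's __main__ block consumes --port (see below)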
+ if logfile:
+ log_file = open(f"apiserver_{port}.log", "a", encoding="utf-8") # pylint: disable=consider-using-with
+ stdout = stderr = log_file
+ else:
+ stdout = stderr = subprocess.PIPE
- return popen_queue.get().pid
+ process = subprocess.Popen(client_args, stdout=stdout, stderr=stderr) # pylint: disable=consider-using-with
+ logger.info("Server started on port %i with PID %i", port, process.pid)
+ return process.pid
def stop_server(pid: int) -> None:
- """Stop the uvicorn server for FastAPI."""
+ """Stop the uvicorn server for FastAPI when started via the client"""
try:
proc = psutil.Process(pid)
proc.terminate()
proc.wait()
+ logger.info("API server stopped.")
except (psutil.NoSuchProcess, psutil.AccessDenied) as ex:
logger.error("Failed to terminate process with PID: %i - %s", pid, ex)
- logger.info("API server stopped.")
-
##########################################
# Server App and API Key
@@ -136,68 +131,124 @@ def get_api_key() -> str:
return os.getenv("API_SERVER_KEY")
-def verify_key(
+def fastapi_verify_key(
http_auth: Annotated[
HTTPAuthorizationCredentials,
Depends(HTTPBearer(description="Please provide API_SERVER_KEY.")),
],
) -> None:
- """Verify that the provided API key is correct."""
+ """FastAPI: Verify that the provided API key is correct."""
if http_auth.credentials != get_api_key():
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
-def register_endpoints(noauth: APIRouter, auth: APIRouter):
+##########################################
+# Endpoint Registration
+##########################################
+async def register_endpoints(mcp: FastMCP, auth: APIRouter, noauth: APIRouter):
"""Register API Endpoints - Imports to avoid bootstrapping before config file read
New endpoints need to be registered in server.api.v1.__init__.py
"""
- import server.api.v1 as api_v1 # pylint: disable=import-outside-toplevel
+ logger.debug("Starting Endpoint Registration")
+ # pylint: disable=import-outside-toplevel
+ import server.api.v1 as api_v1
+ from server.mcp import register_all_mcp
# No-Authentication (probes only)
noauth.include_router(api_v1.probes.noauth, prefix="/v1", tags=["Probes"])
# Authenticated
auth.include_router(api_v1.chat.auth, prefix="/v1/chat", tags=["Chatbot"])
- auth.include_router(api_v1.databases.auth, prefix="/v1/databases", tags=["Config - Databases"])
auth.include_router(api_v1.embed.auth, prefix="/v1/embed", tags=["Embeddings"])
- auth.include_router(api_v1.models.auth, prefix="/v1/models", tags=["Config - Models"])
- auth.include_router(api_v1.oci.auth, prefix="/v1/oci", tags=["Config - Oracle Cloud Infrastructure"])
- auth.include_router(api_v1.prompts.auth, prefix="/v1/prompts", tags=["Tools - Prompts"])
auth.include_router(api_v1.selectai.auth, prefix="/v1/selectai", tags=["SelectAI"])
- auth.include_router(api_v1.settings.auth, prefix="/v1/settings", tags=["Tools - Settings"])
+ auth.include_router(api_v1.prompts.auth, prefix="/v1/prompts", tags=["Tools - Prompts"])
auth.include_router(api_v1.testbed.auth, prefix="/v1/testbed", tags=["Tools - Testbed"])
+ auth.include_router(api_v1.settings.auth, prefix="/v1/settings", tags=["Config - Settings"])
+ auth.include_router(api_v1.databases.auth, prefix="/v1/databases", tags=["Config - Databases"])
+ auth.include_router(api_v1.models.auth, prefix="/v1/models", tags=["Config - Models"])
+ auth.include_router(api_v1.oci.auth, prefix="/v1/oci", tags=["Config - Oracle Cloud Infrastructure"])
+ auth.include_router(api_v1.mcp.auth, prefix="/v1/mcp", tags=["Config - MCP Servers"])
+
+ # Auto-discover all MCP tools and register them on the FastMCP engine and auth router
+ await register_all_mcp(mcp, auth)
+ logger.debug("Finished Endpoint Registration")
#############################################################################
# APP FACTORY
#############################################################################
-def create_app(config: str = None) -> FastAPI:
- """Create and configure the FastAPI app."""
+async def create_app(config: str = "") -> FastAPI:
+ """FastAPI Application Factory"""
+
if not config:
config = configfile.config_file_path()
config_file = Path(os.getenv("CONFIG_FILE", config))
configfile.ConfigStore.load_from_file(config_file)
- app = FastAPI(
+ # FastMCP Server
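+ # The existing API key doubles as a static bearer token for MCP clients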
+ fastmcp_verifier = StaticTokenVerifier(
+ tokens={get_api_key(): {"client_id": "optimizer", "scopes": ["read", "write"]}}
+ )
+ settings.stateless_http = True
+ fastmcp_app = FastMCP(
+ name="Oracle AI Optimizer and Toolkit MCP Server",
+ version=__version__,
+ auth=fastmcp_verifier,
+ include_fastmcp_meta=False,
+ )
+ fastmcp_engine = fastmcp_app.http_app(path="/")
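+ # http_app() returns an ASGI sub-app; its lifespan must run inside FastAPI's
+ # lifespan (see combined_lifespan below) or MCP session state never initializes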
+
+ @asynccontextmanager
+ async def combined_lifespan(fastapi_app: FastAPI):
+ """Ensures all MCP Servers are cleaned up"""
+ async with fastmcp_engine.lifespan(fastapi_app):
+ yield
+ # Shutdown cleanup
+ logger.info("Cleaning up leftover processes...")
+ parent = psutil.Process(os.getpid())
+ children = parent.children(recursive=True)
+ for p in children:
+ try:
+ p.terminate()
+ except psutil.NoSuchProcess:
+ continue
+ # Wait synchronously, outside the event loop
+ _, still_alive = psutil.wait_procs(children, timeout=3)
+ for p in still_alive:
+ try:
+ p.kill()
+ except psutil.NoSuchProcess:
+ continue
+
+ # FastAPI Server
+ fastapi_app = FastAPI(
title="Oracle AI Optimizer and Toolkit",
version=__version__,
docs_url="/v1/docs",
openapi_url="/v1/openapi.json",
+ lifespan=combined_lifespan,
license_info={
"name": "Universal Permissive License",
"url": "http://oss.oracle.com/licenses/upl",
},
)
+ # Store MCP in the app state
+ fastapi_app.state.fastmcp_app = fastmcp_app
+ # Register MCP Server into FastAPI
+ fastapi_app.mount("/mcp", fastmcp_engine)
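+ # MCP clients can now connect at http://<host>:<port>/mcp/ using the
+ # "Authorization: Bearer <API_SERVER_KEY>" header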
+ # Setup Routes and Register non-MCP endpoints
noauth = APIRouter()
- auth = APIRouter(dependencies=[Depends(verify_key)])
+ auth = APIRouter(dependencies=[Depends(fastapi_verify_key)])
- # Register Endpoints
- register_endpoints(noauth, auth)
- app.include_router(noauth)
- app.include_router(auth)
+ # Register the endpoints
+ await register_endpoints(fastmcp_app, auth, noauth)
+ fastapi_app.include_router(noauth)
+ fastapi_app.include_router(auth)
- return app
+ return fastapi_app
if __name__ == "__main__":
@@ -209,10 +260,23 @@ def create_app(config: str = None) -> FastAPI:
default=configfile.config_file_path(),
help="Full path to configuration file (JSON)",
)
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=8000,
+ help="Port to start server",
+ )
args = parser.parse_args()
- PORT = int(os.getenv("API_SERVER_PORT", "8000"))
+ PORT = int(os.getenv("API_SERVER_PORT", str(args.port)))
logger.info("API Server Using port: %i", PORT)
- app = create_app(args.config)
- uvicorn.run(app, host="0.0.0.0", port=PORT, log_config=logging_config.LOGGING_CONFIG)
+ # Sync entrypoint, but calls async factory before running Uvicorn
+ app = asyncio.run(create_app(args.config))
+ uvicorn.run(
+ app,
+ host="0.0.0.0",
+ port=PORT,
+ timeout_graceful_shutdown=5,
+ log_config=logging_config.LOGGING_CONFIG,
+ )
diff --git a/src/pyproject.toml b/src/pyproject.toml
index fce2dc60..87a33b3c 100644
--- a/src/pyproject.toml
+++ b/src/pyproject.toml
@@ -46,7 +46,7 @@ server = [
"langchain-perplexity==0.1.2",
"langchain-xai==0.2.5",
"langgraph==0.6.6",
- "litellm==1.76.1",
+ "litellm==1.76.1",
"llama-index==0.13.3",
"lxml==6.0.0",
"matplotlib==3.10.6",
diff --git a/src/server/Dockerfile b/src/server/Dockerfile
index 70992359..c9113274 100644
--- a/src/server/Dockerfile
+++ b/src/server/Dockerfile
@@ -12,7 +12,7 @@ ENV RUNUSER=oracleai
ENV PATH=/opt/.venv/bin:$PATH
RUN microdnf --nodocs -y update && \
- microdnf --nodocs -y install python3.11 python3.11-pip && \
+ microdnf --nodocs -y install python3.11 python3.11-pip sqlcl && \
microdnf clean all && \
python3.11 -m venv --symlinks --upgrade-deps /opt/.venv && \
groupadd $RUNUSER && \
diff --git a/src/server/agents/tools/selectai.py b/src/server/agents/tools/selectai.py
index fb0f40ac..1b24a943 100644
--- a/src/server/agents/tools/selectai.py
+++ b/src/server/agents/tools/selectai.py
@@ -10,9 +10,10 @@
from langchain_core.tools import BaseTool, tool
from langchain_core.runnables import RunnableConfig
-from common import logging_config
from server.api.utils.databases import execute_sql
+from common import logging_config
+
logger = logging_config.logging.getLogger("server.tools.selectai_executor")
# ------------------------------------------------------------------------------
diff --git a/src/server/api/core/bootstrap.py b/src/server/api/core/bootstrap.py
index fd970758..5db865e0 100644
--- a/src/server/api/core/bootstrap.py
+++ b/src/server/api/core/bootstrap.py
@@ -4,7 +4,7 @@
"""
# spell-checker:ignore genai
-from server.bootstrap import databases, models, oci, prompts, settings
+from server.bootstrap import databases, models, oci, prompts, settings, mcp
from common import logging_config
logger = logging_config.logging.getLogger("api.core.bootstrap")
@@ -14,3 +14,4 @@
OCI_OBJECTS = oci.main()
PROMPT_OBJECTS = prompts.main()
SETTINGS_OBJECTS = settings.main()
+MCP_OBJECTS = mcp.main()
diff --git a/src/server/api/core/mcp.py b/src/server/api/core/mcp.py
new file mode 100644
index 00000000..a0a5d128
--- /dev/null
+++ b/src/server/api/core/mcp.py
@@ -0,0 +1,59 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+
+# spell-checker:ignore streamable
+import os
+
+# from langchain_mcp_adapters.client import MultiServerMCPClient
+# from typing import Optional, List, Dict, Any
+# from common.schema import MCPModelConfig, MCPToolConfig, MCPSettings
+# from server.bootstrap import mcp as mcp_bootstrap
+from common import logging_config
+
+logger = logging_config.logging.getLogger("api.core.mcp")
+
+
+def get_client(server: str = "http://127.0.0.1", port: int = 8000, client: str = None) -> dict:
+ """Get the MCP Client Configuration"""
+ mcp_client = {
+ "mcpServers": {
+ "optimizer": {
+ "type": "streamableHttp",
+ "transport": "streamable_http",
+ "url": f"{server}:{port}/mcp/",
+ "headers": {"Authorization": f"Bearer {os.getenv('API_SERVER_KEY')}"},
+ }
+ }
+ }
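+ # Some MCP clients expect a "type" key, while langgraph's langchain-mcp-adapters
+ # reads only "transport"; strip the redundant key for langgraph consumers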
+ if client == "langgraph":
+ del mcp_client["mcpServers"]["optimizer"]["type"]
+ return mcp_client
+
+
+# def get_mcp_model(model_id: str) -> Optional[MCPModelConfig]:
+# """Get MCP model configuration by ID"""
+# for model in mcp_bootstrap.MCP_MODELS:
+# if model.model_id == model_id:
+# return model
+# return None
+
+
+# def get_mcp_tool(tool_name: str) -> Optional[MCPToolConfig]:
+# """Get MCP tool configuration by name"""
+# for tool in mcp_bootstrap.MCP_TOOLS:
+# if tool.name == tool_name:
+# return tool
+# return None
+
+
+# def update_mcp_settings(settings: Dict[str, Any]) -> MCPSettings:
+# """Update MCP settings"""
+# if not mcp_bootstrap.MCP_SETTINGS:
+# raise ValueError("MCP settings not initialized")
+
+# for key, value in settings.items():
+# if hasattr(mcp_bootstrap.MCP_SETTINGS, key):
+# setattr(mcp_bootstrap.MCP_SETTINGS, key, value)
+
+# return mcp_bootstrap.MCP_SETTINGS
diff --git a/src/server/api/core/models.py b/src/server/api/core/models.py
index a8114224..b86f2a8a 100644
--- a/src/server/api/core/models.py
+++ b/src/server/api/core/models.py
@@ -44,9 +44,13 @@ def get_model(
) -> Union[list[Model], Model, None]:
"""Used in direct call from list_models and agents.models"""
model_objects = bootstrap.MODEL_OBJECTS
-
- logger.debug("%i models are defined", len(model_objects))
-
+ logger.debug(
+ "Filtering %i models for id: %s; type: %s; disabled: %s",
+ len(model_objects),
+ model_id,
+ model_type,
+ include_disabled
+ )
model_filtered = [
model
for model in model_objects
diff --git a/src/server/api/core/settings.py b/src/server/api/core/settings.py
index 0f993721..81de678e 100644
--- a/src/server/api/core/settings.py
+++ b/src/server/api/core/settings.py
@@ -54,11 +54,15 @@ def get_server_config() -> Configuration:
prompt_objects = bootstrap.PROMPT_OBJECTS
prompt_configs = list(prompt_objects)
+ # mcp_objects = bootstrap.MCP_OBJECTS
+ # mcp_configs = list(mcp_objects)
+
full_config = {
"database_configs": database_configs,
"model_configs": model_configs,
"oci_configs": oci_configs,
"prompt_configs": prompt_configs,
+ # "mcp_configs": mcp_configs,
}
return full_config
@@ -92,6 +96,9 @@ def update_server_config(config_data: dict) -> None:
if "prompt_configs" in config_data:
bootstrap.PROMPT_OBJECTS = config.prompt_configs or []
+ if "mcp_configs" in config_data:
+ bootstrap.MCP_OBJECTS = config.mcp_configs or []
+
def load_config_from_json_data(config_data: dict, client: ClientIdType = None) -> None:
"""Shared logic for loading settings from JSON data."""
diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py
index 58b71780..39f1b54d 100644
--- a/src/server/api/utils/chat.py
+++ b/src/server/api/utils/chat.py
@@ -21,8 +21,7 @@
from server.api.core.models import UnknownModelError
-from common import schema
-from common import logging_config
+from common import schema, logging_config
logger = logging_config.logging.getLogger("api.utils.chat")
diff --git a/src/server/api/utils/embed.py b/src/server/api/utils/embed.py
index 173774b9..ef8abf75 100644
--- a/src/server/api/utils/embed.py
+++ b/src/server/api/utils/embed.py
@@ -27,9 +27,7 @@
import server.api.utils.databases as utils_databases
-from common import schema, functions
-
-from common import logging_config
+from common import schema, functions, logging_config
logger = logging_config.logging.getLogger("api.utils.embed")
diff --git a/src/server/api/utils/mcp.py b/src/server/api/utils/mcp.py
new file mode 100644
index 00000000..e70f72e5
--- /dev/null
+++ b/src/server/api/utils/mcp.py
@@ -0,0 +1,140 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker:ignore astream selectai
+
+import os
+import time
+from typing import Literal, AsyncGenerator
+import json
+import oci
+
+from langchain_core.messages import HumanMessage
+from langchain_core.runnables import RunnableConfig
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from langgraph.graph.state import CompiledStateGraph
+
+import server.api.core.settings as core_settings
+import server.api.core.oci as core_oci
+import server.api.utils.models as util_models
+import server.api.core.mcp as core_mcp
+import server.mcp.graph as graph
+
+from common import logging_config, schema
+
+logger = logging_config.logging.getLogger("api.utils.mcp")
+
+def get_client(server: str = "http://127.0.0.1", port: int = 8000) -> dict:
+ """Get the MCP Client Configuration"""
+ mcp_client = {
+ "mcpServers": {
+ "optimizer": {
+ "type": "streamableHttp",
+ "transport": "streamable_http",
+ "url": f"{server}:{port}/mcp/",
+ "headers": {"Authorization": f"Bearer {os.getenv('API_SERVER_KEY')}"},
+ }
+ }
+ }
+
+ return mcp_client
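+
+# Example (sketch): feed the returned mapping to langchain-mcp-adapters, e.g.
+# tools = await MultiServerMCPClient(get_client()["mcpServers"]).get_tools()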
+
+def error_response(call: str, message: str, model: dict) -> dict:
+ """Send the error as a response"""
+ response = message
+ if call != "streams":
+ response = {
+ "id": "error",
+ "choices": [{"message": {"role": "assistant", "content": message}, "index": 0, "finish_reason": "stop"}],
+ "created": int(time.time()),
+ "model": model["model"],
+ "object": "chat.completion",
+ }
+ logger.debug("Returning Error Response: %s", response)
+ return response
+
+
+async def completion_generator(
+ client: schema.ClientIdType, request: schema.ChatRequest, call: Literal["completions", "streams"]
+) -> AsyncGenerator[str, None]:
+ """MCP Completion Requests"""
+ client_settings = core_settings.get_client_settings(client)
+ model = request.model_dump()
+ logger.debug("Settings: %s", client_settings)
+ logger.debug("Request: %s", model)
+
+ # Establish LLM params: use the model from the request if given, otherwise fall back to client settings
+ if not model["model"]:
+ model = client_settings.ll_model.model_dump()
+
+ # Get OCI Settings
+ oci_config = core_oci.get_oci(client=client)
+
+ # Setup Language Model
+ ll_model = util_models.get_client(model, oci_config)
+ if not ll_model:
+ yield error_response("I'm unable to initialise the Language Model. Please refresh the application.", model)
+ return
+
+ # Setup MCP and bind tools
+ mcp_client = MultiServerMCPClient(
+ {"optimizer": core_mcp.get_client(client="langgraph")["mcpServers"]["optimizer"]}
+ )
+ tools = await mcp_client.get_tools()
+ try:
+ ll_model_with_tools = ll_model.bind_tools(tools)
+ except NotImplementedError as ex:
+ yield error_response(call, str(ex), model)
+ raise
+
+ # Build our Graph
+ agent: CompiledStateGraph = graph.main(tools)
+
+ kwargs = {
+ "input": {"messages": [HumanMessage(content=request.messages[0].content)]},
+ "config": RunnableConfig(
+ configurable={"thread_id": client, "ll_model": ll_model_with_tools, "tools": tools},
+ metadata={"use_history": client_settings.ll_model.chat_history},
+ ),
+ }
+
+ last_response = None
+ try:
+ async for chunk in agent.astream_events(**kwargs, version="v2"):
+ # The below will produce A LOT of output; uncomment when desperate
+ # logger.debug("Streamed Chunk: %s", chunk)
+ if chunk["event"] == "on_chat_model_stream":
+ if "tools_condition" in str(chunk["metadata"]["langgraph_triggers"]):
+ continue # Skip Tool Call messages
+ if "vs_retrieve" in str(chunk["metadata"]["langgraph_node"]):
+ continue # Skip Fake-Tool Call messages
+ content = chunk["data"]["chunk"].content
+ if content != "" and call == "streams":
+ yield content.encode("utf-8")
+ last_response = chunk["data"]
+ except oci.exceptions.ServiceError as ex:
+ error_details = json.loads(ex.message).get("message", "")
+ yield error_response(call, error_details, model)
+ raise
+
+ # Clean Up
+ if call == "streams":
+ yield "[stream_finished]" # This will break the Chatbot loop
+ elif call == "completions":
+ final_response = last_response["output"]["final_response"]
+ yield final_response # This will be captured for ChatResponse
diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py
index ae2dd93d..36bfc586 100644
--- a/src/server/api/utils/models.py
+++ b/src/server/api/utils/models.py
@@ -16,8 +16,7 @@
import server.api.core.models as core_models
from common.functions import is_url_accessible
-from common import schema
-from common import logging_config
+from common import logging_config, schema
logger = logging_config.logging.getLogger("api.utils.models")
@@ -159,19 +158,18 @@ def get_client_embed(model_config: dict, oci_config: schema.OracleCloudSettings)
else:
if provider == "hosted_vllm":
kwargs = {
- "provider": "openai",
- "model": full_model_config["id"],
- "base_url": full_model_config.get("api_base"),
- "check_embedding_ctx_length":False #To avoid Tiktoken pre-transform on not OpenAI provided server
+ "provider": "openai",
+ "model": full_model_config["id"],
+ "base_url": full_model_config.get("api_base"),
+ "check_embedding_ctx_length": False, # To avoid Tiktoken pre-transform on not OpenAI provided server
}
else:
kwargs = {
- "provider": provider,
- "model": full_model_config["id"],
- "base_url": full_model_config.get("api_base"),
+ "provider": provider,
+ "model": full_model_config["id"],
+ "base_url": full_model_config.get("api_base"),
}
-
if full_model_config.get("api_key"): # only add if set
kwargs["api_key"] = full_model_config["api_key"]
client = init_embeddings(**kwargs)
diff --git a/src/server/api/utils/testbed.py b/src/server/api/utils/testbed.py
index 84528ecb..f12782ec 100644
--- a/src/server/api/utils/testbed.py
+++ b/src/server/api/utils/testbed.py
@@ -20,8 +20,8 @@
import server.api.utils.databases as utils_databases
import server.api.utils.models as utils_models
-from common import schema
-from common import logging_config
+
+from common import schema, logging_config
logger = logging_config.logging.getLogger("api.utils.testbed")
diff --git a/src/server/api/v1/__init__.py b/src/server/api/v1/__init__.py
index fcd6743f..f9da75e0 100644
--- a/src/server/api/v1/__init__.py
+++ b/src/server/api/v1/__init__.py
@@ -2,5 +2,6 @@
Copyright (c) 2024, 2025, Oracle and/or its affiliates.
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
"""
+# spell-checker:ignore selectai
-from . import chat, databases, embed, models, oci, probes, prompts, testbed, settings, selectai
+from . import chat, databases, embed, models, oci, probes, prompts, testbed, settings, selectai, mcp
diff --git a/src/server/api/v1/chat.py b/src/server/api/v1/chat.py
index 0d4379d2..b9b8bbee 100644
--- a/src/server/api/v1/chat.py
+++ b/src/server/api/v1/chat.py
@@ -18,11 +18,10 @@
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.message import REMOVE_ALL_MESSAGES
-from server.api.utils import chat
-from server.agents import chatbot
+import server.api.utils.mcp as utils_mcp
+import server.mcp.graph as graph
-from common import schema
-from common import logging_config
+from common import schema, logging_config
logger = logging_config.logging.getLogger("endpoints.v1.chat")
@@ -39,7 +38,7 @@ async def chat_post(
) -> ModelResponse:
"""Full Completion Requests"""
last_message = None
- async for chunk in chat.completion_generator(client, request, "completions"):
+ async for chunk in utils_mcp.completion_generator(client, request, "completions"):
last_message = chunk
return last_message
@@ -55,7 +54,7 @@ async def chat_stream(
) -> StreamingResponse:
"""Completion Requests"""
return StreamingResponse(
- chat.completion_generator(client, request, "streams"),
+ utils_mcp.completion_generator(client, request, "streams"),
media_type="application/octet-stream",
)
@@ -67,7 +66,8 @@ async def chat_stream(
)
async def chat_history_clean(client: schema.ClientIdType = Header(default="server")) -> list[ChatMessage]:
"""Delete all Chat History"""
- agent: CompiledStateGraph = chatbot.chatbot_graph
+ agent: CompiledStateGraph = graph.main(list())
+ # agent: CompiledStateGraph = chatbot.chatbot_graph
try:
_ = agent.update_state(
config=RunnableConfig(
@@ -89,7 +89,8 @@ async def chat_history_clean(client: schema.ClientIdType = Header(default="serve
)
async def chat_history_return(client: schema.ClientIdType = Header(default="server")) -> list[ChatMessage]:
"""Return Chat History"""
- agent: CompiledStateGraph = chatbot.chatbot_graph
+ agent: CompiledStateGraph = graph.main(list())
+ # agent: CompiledStateGraph = chatbot.chatbot_graph
try:
state_snapshot = agent.get_state(
config=RunnableConfig(
diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py
index 221c2ccf..fd88cae4 100644
--- a/src/server/api/v1/databases.py
+++ b/src/server/api/v1/databases.py
@@ -8,8 +8,7 @@
import server.api.utils.databases as utils_databases
-from common import schema
-from common import logging_config
+from common import schema, logging_config
logger = logging_config.logging.getLogger("endpoints.v1.databases")
diff --git a/src/server/api/v1/mcp.py b/src/server/api/v1/mcp.py
new file mode 100644
index 00000000..68e40b9a
--- /dev/null
+++ b/src/server/api/v1/mcp.py
@@ -0,0 +1,131 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+This module serves the API layer; it is not used by backend.py.
+"""
+
+# spell-checker:ignore noauth fastmcp healthz
+from fastapi import APIRouter, Request, Depends
+from fastmcp import FastMCP, Client
+
+import server.api.utils.mcp as utils_mcp
+
+from common import logging_config
+
+logger = logging_config.logging.getLogger("api.v1.mcp")
+
+auth = APIRouter()
+
+
+def get_mcp(request: Request) -> FastMCP:
+ """Get the MCP engine from the app state"""
+ return request.app.state.fastmcp_app
+
+
+@auth.get(
+ "/client",
+ description="Get MCP Client Configuration",
+ response_model=dict,
+)
+async def get_client(server: str = "http://127.0.0.1", port: int = 8000) -> dict:
+ "Get MCP Client Configuration"
+ return utils_mcp.get_client(server, port)
+
+
+@auth.get(
+ "/tools",
+ description="List available MCP tools",
+ response_model=list[dict],
+)
+async def get_tools(mcp_engine: FastMCP = Depends(get_mcp)) -> list[dict]:
+ """List MCP tools"""
+ tools_info = []
+ async with Client(mcp_engine) as client:
+ tools = await client.list_tools()
+ logger.debug("MCP Tools: %s", tools)
+ for tool_object in tools:
+ tools_info.append(tool_object.model_dump())
+
+ return tools_info
+
+
+@auth.get(
+ "/resources",
+ description="List MCP resources",
+ response_model=list[dict],
+)
+async def mcp_list_resources(mcp_engine: FastMCP = Depends(get_mcp)) -> list[dict]:
+ """List MCP Resources"""
+ resources_info = []
+ async with Client(mcp_engine) as client:
+ resources = await client.list_resources()
+ logger.debug("MCP Resources: %s", resources)
+ for resources_object in resources:
+ resources_info.append(resources_object.model_dump())
+
+ return resources_info
+
+
+@auth.get(
+ "/prompts",
+ description="List MCP prompts",
+ response_model=list[dict],
+)
+async def mcp_list_prompts(mcp_engine: FastMCP = Depends(get_mcp)) -> list[dict]:
+ """List MCP Prompts"""
+ prompts_info = []
+ async with Client(mcp_engine) as client:
+ prompts = await client.list_prompts()
+ logger.debug("MCP Prompts: %s", prompts)
+ for prompts_object in prompts:
+ prompts_info.append(prompts_object.model_dump())
+
+ return prompts_info
+
+
+# @auth.post("/execute", description="Execute an MCP tool", response_model=dict)
+# async def mcp_execute_tool(request: McpToolCallRequest):
+# """Execute MCP Tool"""
+# mcp_engine = mcp_engine_obj()
+# if not mcp_engine:
+# raise HTTPException(status_code=503, detail="MCP Engine not initialized.")
+# try:
+# result = await mcp_engine.execute_mcp_tool(request.tool_name, request.tool_args)
+# return {"result": result}
+# except Exception as ex:
+# logger.error("Error executing MCP tool: %s", ex)
+# raise HTTPException(status_code=500, detail=str(ex)) from ex
+
+
+# @auth.post("/chat", description="Chat with MCP engine", response_model=dict)
+# async def chat_endpoint(request: ChatRequest):
+# """Chat with MCP Engine"""
+# mcp_engine = mcp_engine_obj()
+# if not mcp_engine:
+# raise HTTPException(status_code=503, detail="MCP Engine not initialized.")
+# try:
+# message_history = request.message_history or [{"role": "user", "content": request.query}]
+# response_text, _ = await mcp_engine.invoke(message_history=message_history)
+# return {"response": response_text}
+# except Exception as ex:
+# logger.error("Error in MCP chat: %s", ex)
+# raise HTTPException(status_code=500, detail=str(ex)) from ex
diff --git a/src/server/api/v1/oci.py b/src/server/api/v1/oci.py
index 56a7f240..3151877e 100644
--- a/src/server/api/v1/oci.py
+++ b/src/server/api/v1/oci.py
@@ -12,8 +12,7 @@
import server.api.utils.oci as utils_oci
import server.api.utils.models as utils_models
-from common import schema
-from common import logging_config
+from common import schema, logging_config
logger = logging_config.logging.getLogger("endpoints.v1.oci")
diff --git a/src/server/api/v1/probes.py b/src/server/api/v1/probes.py
index 6dba7c3d..34a986c6 100644
--- a/src/server/api/v1/probes.py
+++ b/src/server/api/v1/probes.py
@@ -2,13 +2,20 @@
Copyright (c) 2024, 2025, Oracle and/or its affiliates.
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
"""
-# spell-checker:ignore noauth
-from fastapi import APIRouter
+# spell-checker:ignore noauth fastmcp healthz
+from datetime import datetime
+from fastapi import APIRouter, Request, Depends
+from fastmcp import FastMCP
noauth = APIRouter()
+def get_mcp(request: Request) -> FastMCP:
+ """Get the MCP engine from the app state"""
+ return request.app.state.fastmcp_app
+
+
@noauth.get("/liveness")
async def liveness_probe():
"""Kubernetes liveness probe"""
@@ -19,3 +26,19 @@ async def liveness_probe():
async def readiness_probe():
"""Kubernetes readiness probe"""
return {"status": "ready"}
+
+
+@noauth.get("/mcp/healthz")
+def mcp_healthz(mcp_engine: FastMCP = Depends(get_mcp)):
+ """Check if MCP server is ready."""
+ if mcp_engine is None:
+ return {"status": "not ready"}
+
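+ # Note: reaches into FastMCP internals (_mcp_server); may break across versions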
+ server = mcp_engine.__dict__["_mcp_server"].__dict__
+ return {
+ "status": "ready",
+ "name": server["name"],
+ "version": server["version"],
+ "available_tools": len(getattr(mcp_engine, "available_tools", [])) if mcp_engine else 0,
+ "timestamp": datetime.now().isoformat(),
+ }
diff --git a/src/server/api/v1/prompts.py b/src/server/api/v1/prompts.py
index 64713b61..c6362fb7 100644
--- a/src/server/api/v1/prompts.py
+++ b/src/server/api/v1/prompts.py
@@ -9,8 +9,7 @@
import server.api.core.prompts as core_prompts
-from common import schema
-from common import logging_config
+from common import schema, logging_config
logger = logging_config.logging.getLogger("endpoints.v1.prompts")
diff --git a/src/server/api/v1/settings.py b/src/server/api/v1/settings.py
index be810137..512303f4 100644
--- a/src/server/api/v1/settings.py
+++ b/src/server/api/v1/settings.py
@@ -11,8 +11,7 @@
import server.api.core.settings as core_settings
-from common import schema
-from common import logging_config
+from common import logging_config, schema
logger = logging_config.logging.getLogger("endpoints.v1.settings")
@@ -37,7 +36,7 @@ async def settings_get(
full_config: bool = False,
incl_sensitive: bool = Depends(_incl_sensitive_param),
incl_readonly: bool = Depends(_incl_readonly_param),
-) -> Union[schema.Configuration, schema.Settings]:
+) -> Union[schema.Configuration, schema.Settings, JSONResponse]:
"""Get settings for a specific client by name"""
try:
client_settings = core_settings.get_client_settings(client)
@@ -54,8 +53,13 @@ async def settings_get(
model_configs=config.get("model_configs"),
oci_configs=config.get("oci_configs"),
prompt_configs=config.get("prompt_configs"),
+ mcp_configs=config.get("mcp_configs", None),
)
- return JSONResponse(content=response.model_dump_public(incl_sensitive=incl_sensitive, incl_readonly=incl_readonly))
+ if incl_sensitive or incl_readonly:
+ return JSONResponse(
+ content=response.model_dump_public(incl_sensitive=incl_sensitive, incl_readonly=incl_readonly)
+ )
+ return response
@auth.patch(
@@ -113,12 +117,12 @@ async def load_settings_from_file(
pass
try:
- if not file.filename.endswith(".json"):
+ if not file.filename or not file.filename.endswith(".json"):
raise HTTPException(status_code=400, detail="Settings: Only JSON files are supported.")
contents = await file.read()
config_data = json.loads(contents)
core_settings.load_config_from_json_data(config_data, client)
- return {"message": "Configuration loaded successfully."}
+ return JSONResponse(content={"message": "Configuration loaded successfully."})
except json.JSONDecodeError as ex:
raise HTTPException(status_code=400, detail="Settings: Invalid JSON file.") from ex
except KeyError as ex:
@@ -147,7 +151,7 @@ async def load_settings_from_json(
try:
core_settings.load_config_from_json_data(payload.model_dump(), client)
- return {"message": "Configuration loaded successfully."}
+ return JSONResponse(content={"message": "Configuration loaded successfully."})
except json.JSONDecodeError as ex:
raise HTTPException(status_code=400, detail="Settings: Invalid JSON file.") from ex
except KeyError as ex:
diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py
index 0f38fbce..9d28bfbf 100644
--- a/src/server/api/v1/testbed.py
+++ b/src/server/api/v1/testbed.py
@@ -27,8 +27,7 @@
from server.api.v1 import chat
-from common import schema
-from common import logging_config
+from common import logging_config, schema
logger = logging_config.logging.getLogger("endpoints.v1.testbed")
@@ -90,7 +89,9 @@ async def testbed_testset_qa(
client: schema.ClientIdType = Header(default="server"),
) -> schema.TestSetQA:
"""Get TestSet Q&A"""
- return utils_testbed.get_testset_qa(db_conn=utils_databases.get_client_database(client).connection, tid=tid.upper())
+ return utils_testbed.get_testset_qa(
+ db_conn=utils_databases.get_client_database(client).connection, tid=tid.upper()
+ )
@auth.delete(
diff --git a/src/server/bootstrap/mcp.py b/src/server/bootstrap/mcp.py
new file mode 100644
index 00000000..c958e102
--- /dev/null
+++ b/src/server/bootstrap/mcp.py
@@ -0,0 +1,89 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+
+from typing import List, Optional
+import os
+
+from server.bootstrap.configfile import ConfigStore
+from common.schema import MCPSettings, MCPModelConfig, MCPToolConfig
+from common import logging_config
+
+logger = logging_config.logging.getLogger("bootstrap.mcp")
+
+# Global configuration holders
+MCP_SETTINGS: Optional[MCPSettings] = None
+MCP_MODELS: List[MCPModelConfig] = []
+MCP_TOOLS: List[MCPToolConfig] = []
+
+
+def load_mcp_settings(config: dict) -> None:
+ """Load MCP configuration from config file"""
+ global MCP_SETTINGS, MCP_MODELS, MCP_TOOLS
+
+ # Convert to settings object first
+ mcp_settings = MCPSettings(
+ models=[MCPModelConfig(**model) for model in config.get("models", [])],
+ tools=[MCPToolConfig(**tool) for tool in config.get("tools", [])],
+ default_model=config.get("default_model"),
+ enabled=config.get("enabled", True),
+ )
+
+ # Set globals
+ MCP_SETTINGS = mcp_settings
+ MCP_MODELS = mcp_settings.models
+ MCP_TOOLS = mcp_settings.tools
+
+ logger.info("Loaded %i MCP Models and %i Tools", len(MCP_MODELS), len(MCP_TOOLS))
+
+
+def main() -> MCPSettings:
+ """Bootstrap MCP Configuration"""
+ logger.debug("*** Bootstrapping MCP - Start")
+
+ # Load from ConfigStore if available
+ configuration = ConfigStore.get()
+ if configuration and configuration.mcp_configs:
+ logger.debug("Using MCP configs from ConfigStore")
+ # Convert list of MCPModelConfig objects to MCPSettings
+ mcp_settings = MCPSettings(
+ models=configuration.mcp_configs,
+ tools=[], # No tools in the current schema
+ default_model=configuration.mcp_configs[0].model_id if configuration.mcp_configs else None,
+ enabled=True,
+ )
+ else:
+ # Default MCP configuration
+ mcp_settings = MCPSettings(
+ models=[
+ MCPModelConfig(
+ model_id="llama3.1",
+ service_type="ollama",
+ base_url=os.environ.get("ON_PREM_OLLAMA_URL", "http://localhost:11434"),
+ enabled=True,
+ streaming=False,
+ temperature=1.0,
+ max_tokens=2048,
+ )
+ ],
+ tools=[
+ MCPToolConfig(
+ name="file_reader",
+ description="Read contents of files",
+ parameters={"path": "string", "encoding": "string"},
+ enabled=True,
+ )
+ ],
+ default_model=None,
+ enabled=True,
+ )
+
+ logger.info("Loaded %i MCP Models and %i Tools", len(mcp_settings.models), len(mcp_settings.tools))
+ logger.debug("*** Bootstrapping MCP - End")
+ logger.info("MCP Settings: %s", mcp_settings.model_dump_json())
+ return mcp_settings
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/server/bootstrap/oci.py b/src/server/bootstrap/oci.py
index c40c297f..dd2d7470 100644
--- a/src/server/bootstrap/oci.py
+++ b/src/server/bootstrap/oci.py
@@ -10,9 +10,9 @@
from server.bootstrap.configfile import ConfigStore
-from common import logging_config
from common.schema import OracleCloudSettings
+from common import logging_config
logger = logging_config.logging.getLogger("bootstrap.oci")
diff --git a/src/server/mcp/__init__.py b/src/server/mcp/__init__.py
new file mode 100644
index 00000000..8cb3ea53
--- /dev/null
+++ b/src/server/mcp/__init__.py
@@ -0,0 +1,73 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker:ignore fastapi fastmcp
+
+import importlib
+import pkgutil
+
+from fastapi import APIRouter
+from fastmcp import FastMCP
+
+from common import logging_config
+
+logger = logging_config.logging.getLogger("mcp.__init__.py")
+
+
+async def _discover_and_register(
+ package: str,
+ mcp: FastMCP = None,
+ auth: APIRouter = None,
+):
+ """Import all modules in a package and call their register function."""
+ try:
+ pkg = importlib.import_module(package)
+ except ImportError:
+ logger.warning("Package %s not found, skipping.", package)
+ return
+
+ for module_info in pkgutil.walk_packages(pkg.__path__, prefix=f"{package}."):
+ if module_info.name.endswith("__init__"):
+ continue
+
+ try:
+ module = importlib.import_module(module_info.name)
+ except Exception as ex:
+ logger.error("Failed to import %s: %s", module_info.name, ex)
+ continue
+
+ # Decide what to register based on available functions
+ if hasattr(module, "register"):
+ logger.info("Registering via %s.register()", module_info.name)
+ if ".tools." in module.__name__:
+ await module.register(mcp, auth)
+ if ".proxies." in module.__name__:
+ await module.register(mcp)
+ if ".prompts." in module.__name__:
+ await module.register(mcp)
+ # elif hasattr(module, "register_tool"):
+ # logger.info("Registering tool via %s.register_tool()", module_info.name)
+ # module.register_tool(mcp, auth)
+ # elif hasattr(module, "register_prompt"):
+ # logger.info("Registering prompt via %s.register_prompt()", module_info.name)
+ # module.register_prompt(mcp)
+ # elif hasattr(module, "register_resource"):
+ # logger.info("Registering resource via %s.register_resource()", module_info.name)
+ # module.register_resource(mcp)
+ # elif hasattr(module, "register_proxy"):
+ # logger.info("Registering proxy via %s.register_resource()", module_info.name)
+ # module.register_resource(mcp)
+ else:
+ logger.debug("No register function in %s, skipping.", module_info.name)
+
+
+async def register_all_mcp(mcp: FastMCP, auth: APIRouter):
+ """
+ Auto-discover and register all MCP tools, prompts, resources, and proxies.
+ """
+ logger.info("Starting Registering MCP Components")
+ await _discover_and_register("server.mcp.tools", mcp=mcp, auth=auth)
+ await _discover_and_register("server.mcp.proxies", mcp=mcp)
+ # await _discover_and_register("server.mcp.prompts", mcp=mcp)
+ logger.info("Finished Registering MCP Components")
diff --git a/src/server/mcp/graph.py b/src/server/mcp/graph.py
new file mode 100644
index 00000000..00664032
--- /dev/null
+++ b/src/server/mcp/graph.py
@@ -0,0 +1,171 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker:ignore ainvoke checkpointer
+
+from datetime import datetime, timezone
+
+from langchain_core.messages import SystemMessage, ToolMessage
+from langchain_core.runnables import RunnableConfig
+
+from langgraph.graph import StateGraph, MessagesState, START, END
+from langgraph.prebuilt import ToolNode, tools_condition
+
+from common.schema import ChatResponse, ChatUsage, ChatChoices, ChatMessage
+from launch_server import graph_memory
+
+import common.logging_config as logging_config
+
+logger = logging_config.logging.getLogger("mcp.graph")
+
+
+#############################################################################
+# AGENT STATE
+#############################################################################
+class OptimizerState(MessagesState):
+ """Establish our Agent State Machine"""
+
+ final_response: ChatResponse # OpenAI Response
+ cleaned_messages: list # Messages w/o VS Results
+
+
+#############################################################################
+# NODES and EDGES
+#############################################################################
+def respond(state: OptimizerState, config: RunnableConfig) -> dict:
+ """Format the final message as an OpenAI-compatible ChatResponse"""
+ ai_message = state["messages"][-1]
+ logger.debug("Formatting to OpenAI compatible response: %s", repr(ai_message))
+ if "model_name" in ai_message.response_metadata:
+ model_id = ai_message.response_metadata["model_name"]
+ ai_metadata = ai_message
+ else:
+ model_id = config["metadata"]["ll_model"]
+ ai_metadata = state["messages"][1]
+ logger.debug("Using Metadata from: %s", repr(ai_metadata))
+
+ finish_reason = ai_metadata.response_metadata.get("finish_reason", "stop")
+ if finish_reason == "COMPLETE":
+ finish_reason = "stop"
+ elif finish_reason == "MAX_TOKENS":
+ finish_reason = "length"
+
+ openai_response = ChatResponse(
+ id=ai_message.id,
+ created=int(datetime.now(timezone.utc).timestamp()),
+ model=model_id,
+ usage=ChatUsage(
+ prompt_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("prompt_tokens", -1),
+ completion_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("completion_tokens", -1),
+ total_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("total_tokens", -1),
+ ),
+ choices=[
+ ChatChoices(
+ index=0,
+ message=ChatMessage(
+ role="ai",
+ content=ai_message.content,
+ additional_kwargs=ai_metadata.additional_kwargs,
+ response_metadata=ai_metadata.response_metadata,
+ ),
+ finish_reason=finish_reason,
+ logprobs=None,
+ )
+ ],
+ )
+ return {"final_response": openai_response}
+
+
+async def client(state: OptimizerState, config: RunnableConfig) -> OptimizerState:
+ """Get messages from state based on Thread ID"""
+ logger.debug("Initializing OptimizerState")
+ messages = get_messages(state, config)
+
+ return {"cleaned_messages": messages}
+
+
+#############################################################################
+def get_messages(state: OptimizerState, config: RunnableConfig) -> list:
+ """Return a list of messages that will be passed to the model for completion
+ Leave the state as is for GUI functionality"""
+ use_history = config["metadata"]["use_history"]
+
+ # If user decided for no history, only take the last message
+ state_messages = state["messages"] if use_history else state["messages"][-1:]
+
+ messages = []
+ for msg in state_messages:
+ if isinstance(msg, SystemMessage):
+ continue
+ if isinstance(msg, ToolMessage):
+ if messages: # Check if there are any messages in the list
+ messages.pop() # Remove the last appended message
+ continue
+ messages.append(msg)
+
+ # # insert the system prompt; remaining messages cleaned
+ # if config["metadata"]["sys_prompt"].prompt:
+ # messages.insert(0, SystemMessage(content=config["metadata"]["sys_prompt"].prompt))
+
+ return messages
+
+
+def should_continue(state: OptimizerState):
+ """Route to tools when the model requested tool calls, otherwise format the response"""
+ messages = state["messages"]
+ last_message = messages[-1]
+ if last_message.tool_calls:
+ return "tools"
+ return "respond"
+
+
+# Define call_model function
+async def call_model(state: OptimizerState, config: RunnableConfig):
+ """Invoke the model"""
+ try:
+ model = config["configurable"].get("ll_model", None)
+ messages = state["messages"]
+ response = await model.ainvoke(messages)
+ return {"messages": [response]}
+ except AttributeError as ex:
+ # The model is not in our RunnableConfig
+ return {"messages": [f"I'm sorry; {ex}"]}
+
+
+# #############################################################################
+# # GRAPH
+# #############################################################################
+def main(tools: list):
+ """Define the graph with MCP tool nodes"""
+ # Build the graph
+ workflow = StateGraph(OptimizerState)
+
+ # Define the nodes
+ workflow.add_node("client", client)
+ workflow.add_node("call_model", call_model)
+ workflow.add_node("tools", ToolNode(tools))
+ workflow.add_node("respond", respond)
+
+ # Add Edges
+ workflow.add_edge(START, "client")
+ workflow.add_edge("client", "call_model")
+ workflow.add_conditional_edges(
+ "call_model",
+ should_continue,
+ ["tools", "respond"],
+ )
+ workflow.add_edge("tools", "call_model")
+ workflow.add_edge("respond", END)
+
+ # Compile the graph and return it
+ mcp_graph = workflow.compile(checkpointer=graph_memory)
+ logger.debug("Chatbot Graph Built with tools: %s", tools)
+ # Prints the graph in ASCII; keep commented out in delivered code
+ # mcp_graph.get_graph(xray=True).print_ascii()
+
+ return mcp_graph
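+
+# Resulting topology (after compile):
+# START -> client -> call_model -> (tool_calls? tools -> call_model : respond -> END)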
+
+
+if __name__ == "__main__":
+ main(list())
diff --git a/src/server/mcp/prompts/__init__.py b/src/server/mcp/prompts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/server/mcp/prompts/optimizer.py b/src/server/mcp/prompts/optimizer.py
new file mode 100644
index 00000000..52e03226
--- /dev/null
+++ b/src/server/mcp/prompts/optimizer.py
@@ -0,0 +1,21 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+
+# pylint: disable=unused-argument
+# spell-checker:ignore fastmcp
+from fastmcp.prompts.prompt import PromptMessage, TextContent
+
+
+# Basic prompt returning a string (converted to user message automatically)
+async def register(mcp):
+ """Register Out-of-Box Prompts"""
+ optimizer_tags = {"source", "optimizer"}
+
+ @mcp.prompt(name="basic-example-chatbot", tags=optimizer_tags)
+ def basic_example() -> PromptMessage:
+ """Basic system prompt for chatbot."""
+
+ content = "You are a friendly, helpful assistant."
+ return PromptMessage(role="system", content=TextContent(type="text", text=content))
diff --git a/src/server/mcp/proxies/__init__.py b/src/server/mcp/proxies/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/server/mcp/proxies/sqlcl.py b/src/server/mcp/proxies/sqlcl.py
new file mode 100644
index 00000000..a9f3e445
--- /dev/null
+++ b/src/server/mcp/proxies/sqlcl.py
@@ -0,0 +1,72 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker:ignore sqlcl fastmcp connmgr noupdates savepwd
+
+import os
+import shutil
+import subprocess
+
+import server.api.utils.databases as utils_databases
+
+from common import logging_config
+
+logger = logging_config.logging.getLogger("mcp.proxies.sqlcl")
+
+
+async def register(mcp):
+ """Register the SQLcl MCP Server as Local (via Proxy)"""
+ tool_name = "SQLclProxy"
+
+ sqlcl_binary = shutil.which("sql")
+ if sqlcl_binary:
+ env_vars = os.environ.copy()
+ env_vars["TNS_ADMIN"] = os.getenv("TNS_ADMIN", "tns_admin")
+ config = {
+ "mcpServers": {
+ tool_name: {
+ "name": tool_name,
+ "command": f"{sqlcl_binary}",
+ "args": ["-mcp", "-daemon", "-thin", "-noupdates"],
+ "env": env_vars,
+ }
+ }
+ }
+ databases = utils_databases.get_databases(validate=False)
+ for database in databases:
+ # Start sql in no-login mode
+ try:
+ proc = subprocess.Popen(
+ [sqlcl_binary, "/nolog"],
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ env=env_vars,
+ )
+
+ # Prepare commands: connect, then exit
+ commands = [
+ f"connmgr delete -conn OPTIMIZER_{database.name}",
+ (
+ f"conn -savepwd -save OPTIMIZER_{database.name} "
+ f"-user {database.user} -password {database.password} "
+ f"-url {database.dsn}"
+ ),
+ "exit",
+ ]
+
+ # Send commands joined by newlines
+ proc.communicate("\n".join(commands) + "\n")
+ logger.info("Established Connection Store for: %s", database.name)
+ except subprocess.SubprocessError as ex:
+ logger.error("Failed to create connection store: %s", ex)
+ except Exception as ex:
+ logger.error("Unexpected error creating connection store: %s", ex)
+
+ # Create a proxy to the configured server (auto-creates ProxyClient)
+ proxy = mcp.as_proxy(config, name=tool_name)
+ mcp.mount(proxy, as_proxy=False, prefix="sqlcl")
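+ # Tools from the proxied SQLcl server are listed under the "sqlcl" prefix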
+ else:
+ logger.warning("Not enabling SQLcl MCP server, sqlcl not found in PATH.")
diff --git a/src/server/mcp/resources/__init__.py b/src/server/mcp/resources/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/server/mcp/tools/__init__.py b/src/server/mcp/tools/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/server/mcp/tools/say_hello.py b/src/server/mcp/tools/say_hello.py
new file mode 100644
index 00000000..b581cd90
--- /dev/null
+++ b/src/server/mcp/tools/say_hello.py
@@ -0,0 +1,6 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+
+
+async def register(mcp, auth):
+ """Register the example greeting as both an MCP tool and an HTTP route"""
+ @mcp.tool(name="optimizer_greet")
+ @auth.get("/hello", operation_id="say_hello")
+ def greet(name: str = "World") -> str:
+ """Say hello to someone."""
+ return f"Hello, {name}!"
diff --git a/src/server/mcp_bak/register_mcp_servers.py b/src/server/mcp_bak/register_mcp_servers.py
new file mode 100644
index 00000000..c08a59b2
--- /dev/null
+++ b/src/server/mcp_bak/register_mcp_servers.py
@@ -0,0 +1,36 @@
+from fastapi import FastAPI, APIRouter, Request
+from fastapi.responses import JSONResponse, PlainTextResponse
+from mcp.server import Server
+import json
+
+def mount_mcp(router: APIRouter, prefix: str, mcp_server: Server):
+ @router.get(f"{prefix}/.well-known/mcp.json")
+ async def manifest():
+ return JSONResponse(content=mcp_server.manifest.dict())
+
+ @router.post(f"{prefix}/mcp")
+ async def mcp_api(request: Request):
+ body = await request.body()
+ resp = mcp_server.handle_http(body)
+ try:
+ return JSONResponse(content=json.loads(resp))
+ except Exception:
+ return PlainTextResponse(content=resp)
+
+def register_mcp_servers(app: FastAPI):
+ # Create routers for MCP endpoints
+ mcp_router = APIRouter()
+
+ # Define MCP servers
+ mcp_sqlcl = Server(name="Built-in SQLcl MCP Server")
+
+ # Example tools
+ @mcp.tool()
+ def greet(name: str) -> str:
+ return f"Hello from MCP Server One, {name}!"
+
+ # Mount MCP servers into the router under prefixes
+ mount_mcp(app, "/mcp_sqlcl", mcp_sqlcl)
+
+ # Include the MCP router into the main app
+ app.include_router(mcp_router)
\ No newline at end of file
diff --git a/src/server/mcp_bak/server/archive_mcp.py b/src/server/mcp_bak/server/archive_mcp.py
new file mode 100644
index 00000000..d38a091f
--- /dev/null
+++ b/src/server/mcp_bak/server/archive_mcp.py
@@ -0,0 +1,182 @@
+import json
+import os
+from dotenv import load_dotenv
+import arxiv
+from typing import List
+from mcp.server.fastmcp import FastMCP
+import textwrap
+
+# --- Configuration and Setup ---
+load_dotenv()
+PAPER_DIR = "papers"
+# Initialize FastMCP server with a name
+mcp = FastMCP("research")
+_paper_cache = {}
+
+# --- Tool Definitions ---
+
+@mcp.tool()
+def search_papers(topic: str, max_results: int = 5) -> List[str]:
+ """
+ Searches for papers on arXiv based on a topic and saves their metadata.
+
+ Args:
+ topic (str): The topic to search for.
+ max_results (int): Maximum number of results to retrieve.
+
+ Returns:
+ List[str]: A list of the paper IDs found and saved.
+ """
+ client_arxiv = arxiv.Client()
+ search = arxiv.Search(
+ query=topic,
+ max_results=max_results,
+ sort_by=arxiv.SortCriterion.Relevance
+ )
+ papers = list(client_arxiv.results(search))
+
+ if not papers:
+ # It's good practice to print feedback on the server side
+ print(f"Server: No papers found for topic '{topic}'")
+ return []
+
+ path = os.path.join(PAPER_DIR, topic.lower().replace(" ", "_"))
+ os.makedirs(path, exist_ok=True)
+ file_path = os.path.join(path, "papers_info.json")
+
+ try:
+ with open(file_path, "r") as json_file:
+ papers_info = json.load(json_file)
+ except (FileNotFoundError, json.JSONDecodeError):
+ papers_info = {}
+
+ paper_ids = []
+ for paper in papers:
+ paper_id = paper.get_short_id()
+ paper_ids.append(paper_id)
+ papers_info[paper_id] = {
+ 'title': paper.title,
+ 'authors': [author.name for author in paper.authors],
+ 'summary': paper.summary,
+ 'pdf_url': paper.pdf_url,
+ 'published': str(paper.published.date())
+ }
+
+ with open(file_path, "w") as json_file:
+ json.dump(papers_info, json_file, indent=2)
+
+ print(f"Server: Saved {len(paper_ids)} papers to {file_path}")
+ return paper_ids
+
+@mcp.tool()
+def extract_info(paper_id: str) -> str:
+ """
+ Retrieves saved information for a specific paper ID from all topics.
+ Uses an in-memory cache for performance.
+
+ Args:
+ paper_id (str): The ID of the paper to look for.
+
+ Returns:
+ str: JSON string with paper information if found, else an error message.
+ """
+ # 1. First, check the cache for an exact match
+ if paper_id in _paper_cache:
+ return json.dumps(_paper_cache[paper_id], indent=2)
+
+ # 2. If not in cache, fall back to scanning the saved topic files
+ for item in os.listdir(PAPER_DIR):
+ item_path = os.path.join(PAPER_DIR, item)
+ if os.path.isdir(item_path):
+ file_path = os.path.join(item_path, "papers_info.json")
+ if os.path.isfile(file_path):
+ try:
+ with open(file_path, "r") as json_file:
+ papers_info = json.load(json_file)
+
+ # Search logic (can be simplified if we populate cache first)
+ for key, value in papers_info.items():
+ # Add every paper from this file to the cache to avoid re-reading this file
+ if key not in _paper_cache:
+ _paper_cache[key] = value
+
+ except (FileNotFoundError, json.JSONDecodeError):
+ continue
+
+ # 3. Now that the cache is populated from relevant files, check again.
+ # This handles version differences as well.
+ if paper_id in _paper_cache:
+ return json.dumps(_paper_cache[paper_id], indent=2)
+
+ base_id = paper_id.split('v')[0]
+ for key, value in _paper_cache.items():
+ if key.startswith(base_id):
+ return json.dumps(value, indent=2)
+
+ return f"Error: No saved information found for paper ID {paper_id}."
+
+# --- Resource Definitions ---
+
+@mcp.resource("papers://folders")
+def get_available_folders() -> str:
+ """Lists all available topic folders that contain saved paper information."""
+ print(f"Server: Listing available topic folders in {PAPER_DIR}")
+ folders = []
+ if os.path.exists(PAPER_DIR):
+ for topic_dir in os.listdir(PAPER_DIR):
+ if os.path.isdir(os.path.join(PAPER_DIR, topic_dir)):
+ folders.append(topic_dir)
+
+ content = "# Available Research Topics\n\n"
+ if folders:
+ content += "You can retrieve info for any of these topics using `@`.\n\n"
+ for folder in folders:
+ content += f"- `{folder}`\n"
+ else:
+ content += "No topic folders found. Use `search_papers` to create one."
+ print(f"Server: Found {len(folders)} topic folders.")
+ return content
+
+@mcp.resource("papers://{topic}")
+def get_topic_papers(topic: str) -> str:
+ """Gets detailed information about all saved papers for a specific topic."""
+ print(f"Server: Retrieving papers for topic '{topic}'")
+ topic_dir = topic.lower().replace(" ", "_")
+ papers_file = os.path.join(PAPER_DIR, topic_dir, "papers_info.json")
+
+ if not os.path.exists(papers_file):
+ return f"# No papers found for topic: {topic}"
+
+ with open(papers_file, 'r') as f:
+ papers_data = json.load(f)
+
+ content = f"# Papers on {topic.replace('_', ' ').title()}\n\n"
+ for paper_id, info in papers_data.items():
+ content += f"## {info['title']} (`{paper_id}`)\n"
+ content += f"- **Authors**: {', '.join(info['authors'])}\n"
+ content += f"- **Summary**: {info['summary'][:200]}...\n---\n"
+ print(f"Server: Found {len(papers_data)} papers for topic '{topic}'")
+ return content
+
+# --- Prompt Definition ---
+
+@mcp.prompt()
+def generate_search_prompt(topic: str) -> str:
+ """Generates a system prompt to guide an AI in researching a topic."""
+ return textwrap.dedent(f"""
+ You are a research assistant. Your goal is to provide a comprehensive overview of a topic.
+ When asked about '{topic}', follow these steps:
+ 1. Use the `search_papers` tool to find relevant papers.
+ 2. For each paper ID returned, use the `extract_info` tool to get its details.
+ 3. Synthesize the information from all papers into a cohesive summary.
+ 4. Present the key findings, common themes, and any differing conclusions.
+ Do not present the raw JSON. Format the final output for readability.
+ """)
+
+# --- Main Execution Block ---
+
+if __name__ == "__main__":
+ # Run the server over the stdio transport
+ print("Research MCP Server running on stdio...")
+ mcp.run(transport='stdio')
diff --git a/src/server/mcp_bak/server_config.json b/src/server/mcp_bak/server_config.json
new file mode 100644
index 00000000..3d8b0321
--- /dev/null
+++ b/src/server/mcp_bak/server_config.json
@@ -0,0 +1,20 @@
+{
+ "mcpServers": {
+ "filesystem": {
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-filesystem",
+ "."
+ ]
+ },
+ "research": {
+ "command": "python3",
+ "args": ["server/mcp/server/archive_mcp.py"]
+ },
+ "fetch": {
+ "command": "python3",
+ "args": ["-m", "mcp_server_fetch"]
+ }
+ }
+}
diff --git a/tests/conftest.py b/tests/conftest.py
index 3c8b94f6..ec1b5bf9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,6 +8,7 @@
# pylint: disable=import-outside-toplevel
import os
+import asyncio
# This contains all the environment variables we consume on startup (add as required)
# Used to clear testing environment from users env; Do before any additional imports
@@ -72,7 +73,7 @@ def client():
# Lazy Load
from launch_server import create_app
- app = create_app()
+ app = asyncio.run(create_app())
return TestClient(app)
@@ -110,14 +111,14 @@ def is_port_in_use(port):
server_process = subprocess.Popen(cmd, cwd="src")
- # Wait for server to be ready (up to 30 seconds)
- max_wait = 30
+ # Wait for server to be ready
+ max_wait = 60
start_time = time.time()
while not is_port_in_use(8015):
if time.time() - start_time > max_wait:
server_process.terminate()
server_process.wait()
- raise TimeoutError("Server failed to start within 30 seconds")
+ raise TimeoutError(f"Server failed to start within {max_wait} seconds")
time.sleep(0.5)
yield server_process
@@ -132,7 +133,7 @@ def app_test(auth_headers):
"""Establish Streamlit State for Client to Operate"""
def _app_test(page):
- at = AppTest.from_file(page, default_timeout=30)
+ at = AppTest.from_file(page, default_timeout=60)
at.session_state.server = {
"key": os.environ.get("API_SERVER_KEY"),
"url": os.environ.get("API_SERVER_URL"),
diff --git a/tests/unit/server/api/utils/test_models.py b/tests/unit/server/api/utils/test_models.py
new file mode 100644
index 00000000..04728f13
--- /dev/null
+++ b/tests/unit/server/api/utils/test_models.py
@@ -0,0 +1,36 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+
+# spell-checker: disable
+import os
+import pytest
+
+import server.api.core.models as core_models
+import server.api.utils.models as utils_models
+
+os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+os.environ["LITELLM_DISABLE_SPEND_LOGS"] = "True"
+os.environ["LITELLM_DISABLE_SPEND_UPDATES"] = "True"
+os.environ["LITELLM_DISABLE_END_USER_COST_TRACKING"] = "True"
+os.environ["LITELLM_DROP_PARAMS"] = "True"
+os.environ["LITELLM_DROP_PARAMS"] = "True"
+
+@pytest.fixture(name="models_list")
+def _models_list():
+ model_objects = core_models.get_model()
+ for obj in model_objects:
+ obj.enabled = True
+ return model_objects
+
+
+def test_get_litellm_client(models_list):
+ """Testing LiteLLM Functionality"""
+ assert isinstance(models_list, list)
+ assert len(models_list) > 0
+
+ for model in models_list:
+ print(f"My Model: {model}")
+ if model.id == "mxbai-embed-large":
+ utils_models.get_litellm_client(model.dict())