diff --git a/.gitignore b/.gitignore index 4e9fd655..9fa693f4 100644 --- a/.gitignore +++ b/.gitignore @@ -68,4 +68,3 @@ spring_ai/drop.sql src/client/spring_ai/target/classes/* api_server_key .env - diff --git a/src/client/content/chatbot.py.mcp b/src/client/content/chatbot.py.mcp new file mode 100644 index 00000000..c79146f7 --- /dev/null +++ b/src/client/content/chatbot.py.mcp @@ -0,0 +1,300 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +This file merges the Streamlit Chatbot GUI with the MCPClient for a complete, +runnable example demonstrating their integration. +""" + +# spell-checker:ignore streamlit, oraclevs, selectai, langgraph, prebuilt +import asyncio +import inspect +import json +import base64 + +import streamlit as st +from streamlit import session_state as state + +from client.content.config.tabs.models import get_models + +import client.utils.st_common as st_common +import client.utils.api_call as api_call + +from client.utils.st_footer import render_chat_footer +import common.logging_config as logging_config +from client.mcp.client import MCPClient +from pathlib import Path + +logger = logging_config.logging.getLogger("client.content.chatbot") + + +############################################################################# +# Functions +############################################################################# +def show_vector_search_refs(context): + """When Vector Search Content Found, show the references""" + st.markdown("**References:**") + ref_src = set() + ref_cols = st.columns([3, 3, 3]) + # Create a button in each column + for i, (ref_col, chunk) in enumerate(zip(ref_cols, context[0])): + with ref_col.popover(f"Reference: {i + 1}"): + chunk = context[0][i] + logger.debug("Chunk Content: %s", chunk) + st.subheader("Reference Text", divider="red") + st.markdown(chunk["page_content"]) + try: + ref_src.add(chunk["metadata"]["filename"]) + st.subheader("Metadata", divider="red") + st.markdown(f"File: {chunk['metadata']['source']}") + st.markdown(f"Chunk: {chunk['metadata']['page']}") + except KeyError: + logger.error("Chunk Metadata NOT FOUND!!") + + for link in ref_src: + st.markdown("- " + link) + st.markdown(f"**Notes:** Vector Search Query - {context[1]}") + + +############################################################################# +# MAIN +############################################################################# +async def main() -> None: + """Streamlit GUI""" + try: + get_models() + except api_call.ApiError: + st.stop() + ######################################################################### + # Sidebar Settings + ######################################################################### + ll_models_enabled = st_common.enabled_models_lookup("ll") + if not ll_models_enabled: + st.error("No language models are configured and/or enabled. 
Disabling Client.", icon="🛑") + st.stop() + state.enable_client = True + st_common.tools_sidebar() + st_common.history_sidebar() + st_common.ll_sidebar() + st_common.selectai_sidebar() + st_common.vector_search_sidebar() + if not state.enable_client: + st.stop() + + ######################################################################### + # Chatty-Bot Centre + ######################################################################### + + if "messages" not in state: + state.messages = [] + + st.chat_message("ai").write("Hello, how can I help you?") + + for message in state.messages: + role = message.get("role") + display_role = "" + if role in ("human", "user"): + display_role = "human" + elif role in ("ai", "assistant"): + if not message.get("content") and not message.get("tool_trace"): + continue + display_role = "assistant" + else: + continue + + with st.chat_message(display_role): + if "tool_trace" in message and message["tool_trace"]: + for tool_call in message["tool_trace"]: + with st.expander(f"🛠️ **Tool Call:** `{tool_call['name']}`", expanded=False): + st.text("Arguments:") + st.code(json.dumps(tool_call.get("args", {}), indent=2), language="json") + if "error" in tool_call: + st.text("Error:") + st.error(tool_call["error"]) + else: + st.text("Result:") + st.code(tool_call.get("result", ""), language="json") + if message.get("content"): + # Display file attachments if present + if "attachments" in message and message["attachments"]: + for file in message["attachments"]: + # Show appropriate icon based on file type + if file["type"].startswith("image/"): + st.image(file["preview"], use_container_width=True) + st.markdown(f"🖼️ **{file['name']}** ({file['size'] // 1024} KB)") + elif file["type"] == "application/pdf": + st.markdown(f"📄 **{file['name']}** ({file['size'] // 1024} KB)") + elif file["type"] in ("text/plain", "text/markdown"): + st.markdown(f"📝 **{file['name']}** ({file['size'] // 1024} KB)") + else: + st.markdown(f"📎 **{file['name']}** ({file['size'] // 1024} KB)") + + # Display message content - handle both string and list formats + content = message.get("content") + if isinstance(content, list): + # Extract and display only text parts + text_parts = [part["text"] for part in content if part["type"] == "text"] + st.markdown("\n".join(text_parts)) + else: + st.markdown(content) + + sys_prompt = state.client_settings["prompts"]["sys"] + render_chat_footer() + + if human_request := st.chat_input( + f"Ask your question here... 
(current prompt: {sys_prompt})", + accept_file=True, + file_type=["jpg", "jpeg", "png", "pdf", "txt", "docx"], + key=f"chat_input_{len(state.messages)}", + ): + # Process message with potential file attachments + message = {"role": "user", "content": human_request.text} + + # Handle file attachments + if hasattr(human_request, "files") and human_request.files: + # Store file information separately from content + message["attachments"] = [] + for file in human_request.files: + file_bytes = file.read() + file_b64 = base64.b64encode(file_bytes).decode("utf-8") + message["attachments"].append( + { + "name": file.name, + "type": file.type, + "size": len(file_bytes), + "data": file_b64, + "preview": f"data:{file.type};base64,{file_b64}" if file.type.startswith("image/") else None, + } + ) + + state.messages.append(message) + st.rerun() + if state.messages and state.messages[-1]["role"] == "user": + try: + with st.chat_message("ai"): + with st.spinner("Thinking..."): + client_settings_for_request = state.client_settings.copy() + model_id = client_settings_for_request.get("ll_model", {}).get("model") + if model_id: + all_model_configs = st_common.enabled_models_lookup("ll") + model_config = all_model_configs.get(model_id, {}) + if "api_key" in model_config: + if "ll_model" not in client_settings_for_request: + client_settings_for_request["ll_model"] = {} + client_settings_for_request["ll_model"]["api_key"] = model_config["api_key"] + + # Prepare message history for backend + message_history = [] + for msg in state.messages: + # Create a copy of the message + processed_msg = msg.copy() + + # If there are attachments, include them in the content + if "attachments" in msg and msg["attachments"]: + # Start with the text content + text_content = msg["content"] + + # Handle list content format (from OpenAI API) + if isinstance(text_content, list): + text_parts = [part["text"] for part in text_content if part["type"] == "text"] + text_content = "\n".join(text_parts) + + # Create a list to hold structured content parts + content_list = [{"type": "text", "text": text_content}] + + non_image_references = [] + for attachment in msg["attachments"]: + if attachment["type"].startswith("image/"): + # Only add image URLs for user messages + if msg["role"] in ("human", "user"): + # Normalize image MIME types for compatibility + mime_type = attachment["type"] + if mime_type == "image/jpg": + mime_type = "image/jpeg" + + content_list.append( + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{attachment['data']}", + "detail": "low", + }, + } + ) + else: + # Handle non-image files as text references + non_image_references.append( + f"\n[File: {attachment['name']} ({attachment['size'] // 1024} KB)]" + ) + + # If there were non-image files, append their references to the main text part + if non_image_references: + content_list[0]["text"] += "".join(non_image_references) + + processed_msg["content"] = content_list + # Convert list content to string format + elif isinstance(msg.get("content"), list): + text_parts = [part["text"] for part in msg["content"] if part["type"] == "text"] + processed_msg["content"] = str("\n".join(text_parts)) + # Otherwise, ensure content is a string + else: + processed_msg["content"] = str(msg.get("content", "")) + + message_history.append(processed_msg) + + async with MCPClient(client_settings=client_settings_for_request) as mcp_client: + final_text, tool_trace, new_history = await mcp_client.invoke(message_history=message_history) + + # Update the history for 
display. + # Keep the original message structure with attachments + for i in range(len(new_history) - 1, -1, -1): + if new_history[i].get("role") == "assistant": + # Preserve any attachments from the user message + user_message = state.messages[-1] + if "attachments" in user_message: + new_history[-1]["attachments"] = user_message["attachments"] + + new_history[i]["content"] = final_text + new_history[i]["tool_trace"] = tool_trace + break + + state.messages = new_history + st.rerun() + + except Exception as e: + logger.error("Exception during invoke call:", exc_info=True) + # Extract just the error message + error_msg = str(e) + + # Check if it's a file-related error + if "file" in error_msg.lower() or "image" in error_msg.lower() or "content" in error_msg.lower(): + st.error(f"Error: {error_msg}") + + # Add a button to remove files and retry + if st.button("Remove files and retry", key="remove_files_retry"): + # Remove attachments from the latest message + if state.messages and "attachments" in state.messages[-1]: + del state.messages[-1]["attachments"] + st.rerun() + else: + st.error(f"Error: {error_msg}") + + if st.button("Retry", key="reload_chatbot_error"): + if state.messages and state.messages[-1]["role"] == "user": + state.messages.pop() + st.rerun() + + +if __name__ == "__main__" or ("page" in inspect.stack()[1].filename if inspect.stack() else False): + try: + asyncio.run(main()) + except ValueError as ex: + logger.exception("Bug detected: %s", ex) + st.error("It looks like you found a bug; please open an issue", icon="🛑") + st.stop() + except IndexError as ex: + logger.exception("Unable to contact the server: %s", ex) + st.error("Unable to contact the server, is it running?", icon="🚨") + if st.button("Retry", key="reload_chatbot"): + st_common.clear_state_key("user_client") + st.rerun() diff --git a/src/client/content/config/config.py b/src/client/content/config/config.py index 84a664ba..98a1a8f1 100644 --- a/src/client/content/config/config.py +++ b/src/client/content/config/config.py @@ -11,6 +11,7 @@ from client.content.config.tabs.oci import get_oci, display_oci from client.content.config.tabs.databases import get_databases, display_databases from client.content.config.tabs.models import get_models, display_models +from client.content.config.tabs.mcp import get_mcp, display_mcp def main() -> None: @@ -20,6 +21,7 @@ def main() -> None: get_databases() get_models() get_oci() + get_mcp() tabs_list = [] if not state.disabled["settings"]: @@ -30,6 +32,8 @@ def main() -> None: tabs_list.append("🤖 Models") if not state.disabled["oci_cfg"]: tabs_list.append("☁️ OCI") + if not state.disabled["mcp_cfg"]: + tabs_list.append("🔗 MCP") # Only create tabs if there is at least one tab_index = 0 @@ -53,6 +57,10 @@ def main() -> None: with tabs[tab_index]: display_oci() tab_index += 1 + if not state.disabled["mcp_cfg"]: + with tabs[tab_index]: + display_mcp() + tab_index += 1 if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename: diff --git a/src/client/content/config/tabs/databases.py b/src/client/content/config/tabs/databases.py index 3bf4ee32..cab6b596 100644 --- a/src/client/content/config/tabs/databases.py +++ b/src/client/content/config/tabs/databases.py @@ -26,7 +26,6 @@ def get_databases(force: bool = False) -> None: """Get Databases from API Server""" if force or "database_configs" not in state or not state.database_configs: try: - logger.info("Refreshing state.database_configs") # Validation will be done on currently configured client database # validation includes 
new vector_stores, etc. client_database = state.client_settings.get("database", {}).get("alias", {}) diff --git a/src/client/content/config/tabs/mcp.py b/src/client/content/config/tabs/mcp.py new file mode 100644 index 00000000..75621ba5 --- /dev/null +++ b/src/client/content/config/tabs/mcp.py @@ -0,0 +1,188 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" + +# spell-checker:ignore selectbox healthz +import json + +import streamlit as st +from streamlit import session_state as state + +from client.utils import api_call, st_common + +from common import logging_config + +logger = logging_config.logging.getLogger("client.content.config.tabs.mcp") + + +################################### +# Functions +################################### +def get_mcp_status() -> dict: + """Get MCP Status""" + try: + return api_call.get(endpoint="v1/mcp/healthz") + except api_call.ApiError as ex: + logger.error("Unable to get MCP Status: %s", ex) + return {} + + +def get_mcp_client() -> dict: + """Get MCP Client Configuration""" + try: + params = {"server": {state.server["url"]}, "port": {state.server["port"]}} + mcp_client = api_call.get(endpoint="v1/mcp/client", params=params) + return json.dumps(mcp_client, indent=2) + except api_call.ApiError as ex: + logger.error("Unable to get MCP Client: %s", ex) + return {} + + +def get_mcp(force: bool = False) -> list[dict]: + """Get MCP configs from API Server""" + if force or "mcp_configs" not in state or not state.mcp_configs: + logger.info("Refreshing state.mcp_configs") + endpoints = { + "tools": "v1/mcp/tools", + "prompts": "v1/mcp/prompts", + "resources": "v1/mcp/resources", + } + results = {} + + for key, endpoint in endpoints.items(): + try: + results[key] = api_call.get(endpoint=endpoint) + except api_call.ApiError as ex: + logger.error("Unable to get %s: %s", key, ex) + results[key] = {} + + state.mcp_configs = results + + +def extract_servers() -> list: + """Get a list of distinct MCP servers (by prefix)""" + prefixes = set() + + for _, items in state.mcp_configs.items(): + for item in items or []: # handle None safely + name = item.get("name") + if name and "_" in name: + prefix = name.split("_", 1)[0] + prefixes.add(prefix) + + mcp_servers = sorted(prefixes) + + if "optimizer" in mcp_servers: + mcp_servers.remove("optimizer") + mcp_servers.insert(0, "optimizer") + + return mcp_servers + + +@st.dialog(title="Details", width="large") +def mcp_details(mcp_server: str, mcp_type: str, mcp_name: str) -> None: + """MCP Dialog Box""" + st.header(f"{mcp_name} - MCP server: {mcp_server}") + config = next((t for t in state.mcp_configs[mcp_type] if t.get("name") == f"{mcp_server}_{mcp_name}"), None) + if config.get("description"): + st.code(config["description"], wrap_lines=True, height="content") + if config.get("inputSchema"): + st.subheader("inputSchema", divider="red") + properties = config["inputSchema"].get("properties", {}) + required_fields = set(config["inputSchema"].get("required", [])) + for name, prop in properties.items(): + req = '(required)' if name in required_fields else "" + html = f""" +

<div><code>{name}</code> {req}</div>

+ + """ + st.html(html) + if config.get("outputSchema"): + st.subheader("outputSchema", divider="red") + if config.get("arguments"): + st.subheader("arguments", divider="red") + if config.get("annotations"): + st.subheader("annotations", divider="red") + if config.get("meta"): + st.subheader("meta", divider="red") + + +def render_configs(mcp_server: str, mcp_type: str, configs: list) -> None: + """Render rows of the MCP type""" + data_col_widths = [0.8, 0.2] + table_col_format = st.columns(data_col_widths, vertical_alignment="center") + col1, col2 = table_col_format + col1.markdown("Name", unsafe_allow_html=True) + col2.markdown("​") + for mcp_name in configs: + col1.text_input( + "Name", + value=mcp_name, + label_visibility="collapsed", + disabled=True, + ) + col2.button( + "Details", + on_click=mcp_details, + key=f"{mcp_server}_{mcp_name}_details", + kwargs=dict(mcp_server=mcp_server, mcp_type=mcp_type, mcp_name=mcp_name), + ) + + +############################################################################# +# MAIN +############################################################################# +def display_mcp() -> None: + """Streamlit GUI""" + st.header("Model Context Protocol", divider="red") + try: + get_mcp() + except api_call.ApiError: + st.stop() + mcp_status = get_mcp_status() + if mcp_status.get("status") == "ready": + st.markdown(f""" + The {mcp_status["name"]} is running. + **Version**: {mcp_status["version"]} + """) + with st.expander("Client Configuration"): + st.code(get_mcp_client(), language="json") + else: + st.error("MCP Server is not running!", icon="🛑") + st.stop() + + selected_mcp_server = st.selectbox( + "Configured MCP Server(s):", + options=extract_servers(), + # index=list(database_lookup.keys()).index(state.client_settings["database"]["alias"]), + key="selected_mcp_server", + # on_change=st_common.update_client_settings("database"), + ) + if state.mcp_configs["tools"]: + tools_lookup = st_common.state_configs_lookup("mcp_configs", "name", "tools") + mcp_tools = [key.split("_", 1)[1] for key in tools_lookup if key.startswith(f"{selected_mcp_server}_")] + if mcp_tools: + st.subheader("Tools", divider="red") + render_configs(selected_mcp_server, "tools", mcp_tools) + if state.mcp_configs["prompts"]: + prompts_lookup = st_common.state_configs_lookup("mcp_configs", "name", "prompts") + mcp_prompts = [key.split("_", 1)[1] for key in prompts_lookup if key.startswith(f"{selected_mcp_server}_")] + if mcp_prompts: + st.subheader("Prompts", divider="red") + render_configs(selected_mcp_server, "prompts", mcp_prompts) + if state.mcp_configs["resources"]: + st.subheader("Resources", divider="red") + resources_lookup = st_common.state_configs_lookup("mcp_configs", "name", "resources") + mcp_resources = [key.split("_", 1)[1] for key in resources_lookup if key.startswith(f"{selected_mcp_server}_")] + if mcp_resources: + st.subheader("Resources", divider="red") + render_configs(selected_mcp_server, "resources", mcp_resources) + + +if __name__ == "__main__": + display_mcp() diff --git a/src/client/content/config/tabs/mcp_bak.py b/src/client/content/config/tabs/mcp_bak.py new file mode 100644 index 00000000..89450432 --- /dev/null +++ b/src/client/content/config/tabs/mcp_bak.py @@ -0,0 +1,24 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+""" + +from client.mcp.frontend import display_commands_tab, display_ide_tab, get_fastapi_base_url, get_server_capabilities + +import streamlit as st + +def display_mcp(): + fastapi_base_url = get_fastapi_base_url() + tools, resources, prompts = get_server_capabilities(fastapi_base_url) + if "chat_history" not in st.session_state: + st.session_state.chat_history = [] + + + ide, commands = st.tabs(["🛠️ IDE", "📚 Available Commands"]) + + with ide: + # Display the IDE tab using the original AI Optimizer logic. + display_ide_tab() + with commands: + # Display the commands tab using the original AI Optimizer logic. + display_commands_tab(tools, resources, prompts) diff --git a/src/client/mcp/client.py b/src/client/mcp/client.py new file mode 100644 index 00000000..d4282828 --- /dev/null +++ b/src/client/mcp/client.py @@ -0,0 +1,446 @@ +import json +import os +import time +import asyncio +from dotenv import load_dotenv +from mcp import ClientSession, StdioServerParameters, types +from mcp.client.stdio import stdio_client +from typing import List, Dict, Optional, Tuple, Type, Any +from contextlib import AsyncExitStack + +# --- MODIFICATION: Import LangChain components --- +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage +from langchain_core.language_models.base import BaseLanguageModel +from pydantic import create_model, BaseModel, Field +# Import the specific chat models you want to support +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_cohere import ChatCohere +from langchain_ollama import ChatOllama +from langchain_groq import ChatGroq +from langchain_mistralai import ChatMistralAI + +load_dotenv() + +if os.getenv("IS_STREAMLIT_CONTEXT"): + import nest_asyncio + nest_asyncio.apply() + +class MCPClient: + # MODIFICATION: Changed the constructor to accept client_settings + def __init__(self, client_settings: Dict): + """ + Initialize MCP Client using a settings dictionary from the Streamlit client. + + Args: + client_settings: The state.client_settings object. + """ + # 1. Validate the incoming settings dictionary + if not client_settings or 'll_model' not in client_settings: + raise ValueError("Client settings are incomplete. 'll_model' is required.") + + # 2. 
Store the settings and extract the model ID + self.model_settings = client_settings['ll_model'] + + # This is our new "Service Factory" using LangChain classes + # If no model is specified, we'll initialize with a default one + if 'model' not in self.model_settings or not self.model_settings['model']: + # Set a default model if none is specified + self.model_settings['model'] = 'llama3.1' + # Remove any OpenAI-specific parameters that might cause issues + self.model_settings.pop('openai_api_key', None) + + self.langchain_model = self._create_langchain_model(**self.model_settings) + + self.exit_stack = AsyncExitStack() + self.sessions: Dict[str, ClientSession] = {} + self.tool_to_session: Dict[str, Tuple[ClientSession, types.Tool]] = {} + self.available_prompts: Dict[str, types.Prompt] = {} + self.static_resources: Dict[str, str] = {} + self.dynamic_resources: List[str] = [] + self.resource_to_session: Dict[str, str] = {} + self.prompt_to_session: Dict[str, str] = {} + self.available_tools: List[Dict] = [] + self._stdio_generators: Dict[str, Any] = {} # To store stdio generators for cleanup + print(f"Initialized MCPClient with LangChain model: {self.langchain_model.__class__.__name__}") + + # --- FIX: Add __aenter__ and __aexit__ to make this a context manager --- + async def __aenter__(self): + """Enter the async context, connecting to all servers.""" + await self.connect_to_servers() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the async context, ensuring all connections are cleaned up.""" + await self.cleanup() + + def _create_langchain_model(self, model: str, **kwargs) -> BaseLanguageModel: + """Factory to create and return a LangChain ChatModel instance.""" + # If no model is specified, default to llama3.1 which works with Ollama + if not model: + model = "llama3.1" + # Remove any OpenAI-specific parameters that might cause issues + kwargs.pop('openai_api_key', None) + + model_lower = model.lower() + + # Handle OpenAI models + if model_lower.startswith('gpt-'): + # Check if api_key is in kwargs and rename it to openai_api_key for ChatOpenAI + if 'api_key' in kwargs: + kwargs['openai_api_key'] = kwargs.pop('api_key') + # Remove parameters that shouldn't be passed to ChatOpenAI + kwargs.pop('context_length', None) + kwargs.pop('chat_history', None) + return ChatOpenAI(model=model, **kwargs) + + # Handle Anthropic models + elif model_lower.startswith('claude-'): + kwargs.pop('openai_api_key', None) + return ChatAnthropic(model=model, **kwargs) + + # Handle Google models + elif model_lower.startswith('gemini-'): + kwargs.pop('openai_api_key', None) + return ChatGoogleGenerativeAI(model=model, **kwargs) + + # Handle Mistral models + elif model_lower.startswith('mistral-'): + kwargs.pop('openai_api_key', None) + return ChatMistralAI(model=model, **kwargs) + + # Handle Cohere models + elif model_lower.startswith('cohere-'): + kwargs.pop('openai_api_key', None) + return ChatCohere(model=model, **kwargs) + + # Handle Groq models + elif model_lower.startswith('groq-'): + kwargs.pop('openai_api_key', None) + return ChatGroq(model=model, **kwargs) + + # Default to Ollama for any other model name + else: + return ChatOllama(model=model, **kwargs) + + def _convert_dict_to_langchain_messages(self, message_history: List[Dict]) -> List[BaseMessage]: + """Converts a list of message dictionaries to a list of LangChain message objects.""" + messages: List[BaseMessage] = [] + for msg in message_history: + role = msg.get("role") + content = msg.get("content", "") + if 
role == "user": + messages.append(HumanMessage(content=content)) # type: ignore + elif role == "assistant": + # AIMessage can handle tool calls directly from the dictionary format + tool_calls = msg.get("tool_calls") + messages.append(AIMessage(content=content, tool_calls=tool_calls or [])) # type: ignore + elif role == "system": + messages.append(SystemMessage(content=content)) # type: ignore + elif role == "tool": + messages.append(ToolMessage(content=content, tool_call_id=msg.get("tool_call_id", ""))) # type: ignore + return messages # type: ignore + + def _convert_langchain_messages_to_dict(self, langchain_messages: List[BaseMessage]) -> List[Dict]: + """Converts a list of LangChain message objects back to a list of dictionaries for session state.""" + dict_messages = [] + for msg in langchain_messages: + if isinstance(msg, HumanMessage): + dict_messages.append({"role": "user", "content": msg.content}) + elif isinstance(msg, AIMessage): + # Preserve tool calls in the dictionary format + dict_messages.append({"role": "assistant", "content": msg.content, "tool_calls": msg.tool_calls}) + elif isinstance(msg, SystemMessage): + dict_messages.append({"role": "system", "content": msg.content}) + elif isinstance(msg, ToolMessage): + dict_messages.append({"role": "tool", "content": msg.content, "tool_call_id": msg.tool_call_id}) + return dict_messages + + def _prepare_messages_for_service(self, message_history: List[Dict]) -> List[Dict]: + """ + FIX: Translates the rich message history from the GUI into a simple, + text-only format that AI services can understand. + """ + prepared_messages = [] + for msg in message_history: + content = msg.get("content") + # If content is a list (multimodal), extract only the text. + if isinstance(content, list): + text_content = " ".join( + part["text"] for part in content if part.get("type") == "text" + ) + prepared_messages.append({"role": msg["role"], "content": text_content}) + # Otherwise, use the content as is (assuming it's a string). 
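+            # e.g. a plain {"role": "user", "content": "Hello"} passes through unchanged.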
+ else: + prepared_messages.append(msg) + return prepared_messages + + async def connect_to_servers(self): + try: + config_paths = ["server/mcp/server_config.json", os.path.join(os.path.dirname(__file__), "..", "..", "server", "mcp", "server_config.json")] + servers = {} + for config_path in config_paths: + try: + with open(config_path, "r") as file: + servers = json.load(file).get("mcpServers", {}) + print(f"Loaded MCP server configuration from: {config_path}") + print(f"Found servers: {list(servers.keys())}") + break + except FileNotFoundError: + print(f"MCP server config not found at: {config_path}") + continue + except Exception as e: + print(f"Error reading MCP server config from {config_path}: {e}") + continue + if not servers: + print("No MCP server configuration found!") + for name, config in servers.items(): + print(f"Connecting to MCP server: {name}") + await self.connect_to_server(name, config) + except Exception as e: print(f"Error loading server configuration: {e}") + + async def connect_to_server(self, server_name: str, server_config: dict): + try: + print(f"Connecting to server '{server_name}' with config: {server_config}") + server_params = StdioServerParameters(**server_config) + + # Create the stdio client connection using the exit stack for proper cleanup + try: + read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) + + # Create the client session using the exit stack for proper cleanup + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + + await session.initialize() + self.sessions[server_name] = session + + # Load tools, resources, and prompts from this server + await self._load_server_capabilities(session, server_name) + except RuntimeError as e: + # Handle runtime errors related to task context + if "cancel scope" not in str(e).lower(): + raise + print(f"Warning: Connection to '{server_name}' had context issues: {e}") + except Exception as e: + raise + except Exception as e: + print(f"Failed to connect to '{server_name}': {e}") + import traceback + traceback.print_exc() + + async def _run_async_generator(self, generator): + """Helper method to run an async generator in the current task context.""" + return await generator.__anext__() + + async def _load_server_capabilities(self, session: ClientSession, server_name: str): + """Load tools, resources, and prompts from a connected server.""" + try: + # List tools + tools_list = await session.list_tools() + print(f"Found {len(tools_list.tools)} tools from server '{server_name}'") + for tool in tools_list.tools: + self.tool_to_session[tool.name] = (session, tool) + print(f"Loaded tool '{tool.name}' from server '{server_name}'") + + # List resources + try: + resp = await session.list_resources() + if resp.resources: print(f" - Found Static Resources: {[r.name for r in resp.resources]}") + for resource in resp.resources: + uri = resource.uri.encoded_string() + self.resource_to_session[uri] = server_name + user_shortcut = uri.split('//')[-1] + self.static_resources[user_shortcut] = uri + if resource.name and resource.name != user_shortcut: + self.static_resources[resource.name] = uri + except Exception as e: + print(f"Failed to load resources from server '{server_name}': {e}") + + # Discover DYNAMIC resource templates + try: + # The response object for templates has a `.templates` attribute + resp = await session.list_resource_templates() + if resp.resourceTemplates: print(f" - Found Dynamic Resource Templates: {[t.name for t in resp.resourceTemplates]}") + for 
template in resp.resourceTemplates: + uri = template.uriTemplate + # The key for the session map MUST be the pattern itself. + self.resource_to_session[uri] = server_name + if uri not in self.dynamic_resources: + self.dynamic_resources.append(uri) + except Exception as e: + # This is also okay, some servers don't have dynamic resources. + print(f"Failed to load dynamic resources from server '{server_name}': {e}") + + + # List prompts + try: + prompts_list = await session.list_prompts() + print(f"Found {len(prompts_list.prompts)} prompts from server '{server_name}'") + for prompt in prompts_list.prompts: + self.available_prompts[prompt.name] = prompt + self.prompt_to_session[prompt.name] = server_name + print(f"Loaded prompt '{prompt.name}' from server '{server_name}'") + except Exception as e: + print(f"Failed to load prompts from server '{server_name}': {e}") + + except Exception as e: + print(f"Failed to load capabilities from server '{server_name}': {e}") + + async def _rebuild_mcp_tool_schemas(self): + """Rebuilds the list of tools from connected MCP servers in a LangChain-compatible format.""" + self.available_tools = [] + for _, (_, tool_object) in self.tool_to_session.items(): + # LangChain's .bind_tools can often work directly with this MCP schema + tool_schema = { + "name": tool_object.name, + "description": tool_object.description, + "args_schema": self.create_pydantic_model_from_schema(tool_object.name, tool_object.inputSchema) + } + self.available_tools.append(tool_schema) + print(f"Available tools after rebuild: {len(self.available_tools)}") + + def create_pydantic_model_from_schema(self, name: str, schema: dict) -> Type[BaseModel]: + """Dynamically creates a Pydantic model from a JSON schema for LangChain tool binding.""" + fields = {} + if schema and 'properties' in schema: + for prop_name, prop_details in schema['properties'].items(): + field_type = str # Default to string + # A more robust implementation would map JSON schema types to Python types + if prop_details.get('type') == 'integer': field_type = int + elif prop_details.get('type') == 'number': field_type = float + elif prop_details.get('type') == 'boolean': field_type = bool + + fields[prop_name] = (field_type, Field(..., description=prop_details.get('description'))) + + return create_model(name, **fields) # type: ignore + + async def execute_mcp_tool(self, tool_name: str, tool_args: Dict) -> str: + try: + session, _ = self.tool_to_session[tool_name] + result = await session.call_tool(tool_name, arguments=tool_args) + if not result.content: return "Tool executed successfully." + + # Handle different content types properly + if isinstance(result.content, list): + text_parts = [] + for item in result.content: + # Check if item has a text attribute + if hasattr(item, 'text'): + text_parts.append(str(item.text)) + else: + # Handle other content types + text_parts.append(str(item)) + return " | ".join(text_parts) + else: + return str(result.content) + except Exception as e: + # Check if it's a closed resource error + if "ClosedResourceError" in str(type(e)) or "closed" in str(e).lower(): + raise Exception("MCP session is closed. Please try again.") from e + else: + raise + + async def invoke(self, message_history: List[Dict]) -> Tuple[str, List[Dict], List[Dict]]: + """ + Main entry point. 
Now returns a tuple of: + (final_text_response, tool_calls_trace, new_full_history) + """ + max_retries = 3 + for attempt in range(max_retries): + try: + langchain_messages = self._convert_dict_to_langchain_messages(message_history) + + # Separate the final text response from the tool trace + final_text_response = "" + tool_calls_trace = [] + + max_iterations = 10 + tool_execution_failed = False + for iteration in range(max_iterations): + await self._rebuild_mcp_tool_schemas() + model_with_tools = self.langchain_model.bind_tools(self.available_tools) + response_message: AIMessage = await model_with_tools.ainvoke(langchain_messages) + langchain_messages.append(response_message) + + # Capture the final text response from the last message + if response_message.content: + final_text_response = response_message.content + + if not response_message.tool_calls: + break + + for tool_call in response_message.tool_calls: + tool_name = tool_call['name'] + tool_args = tool_call['args'] + + try: + result_content = await self.execute_mcp_tool(tool_name, tool_args) + tool_calls_trace.append({ + "name": tool_name, + "args": tool_args, + "result": result_content + }) + except Exception as e: + if "MCP session is closed" in str(e) and attempt < max_retries - 1: + print(f"MCP session closed, reinitializing (attempt {attempt + 1})") + await self.cleanup(); await self.connect_to_servers() + await asyncio.sleep(0.1); tool_execution_failed = True; break + else: + result_content = f"Error executing tool {tool_name}: {e}" + tool_calls_trace.append({ + "name": tool_name, + "args": tool_args, + "error": result_content + }) + + langchain_messages.append(ToolMessage(content=result_content, tool_call_id=tool_call['id'])) + + if tool_execution_failed: break + + if tool_execution_failed and attempt < max_retries - 1: continue + + final_history_dict = self._convert_langchain_messages_to_dict(langchain_messages) + + return final_text_response, tool_calls_trace, final_history_dict + + except RuntimeError as e: + if "Event loop is closed" in str(e) and attempt < max_retries - 1: + print(f"Event loop closed, reinitializing model (attempt {attempt + 1})") + self.langchain_model = self._create_langchain_model(**self.model_settings) + await asyncio.sleep(0.1); continue + else: raise Exception("Event loop closed. 
Please try again.") from e + except Exception as e: + if attempt >= max_retries - 1: raise + print(f"Invoke attempt {attempt + 1} failed, retrying: {e}") + await asyncio.sleep(0.1) + + raise Exception("Failed to invoke MCP client after all retries") + + async def cleanup(self): + """Clean up all resources properly.""" + try: + # Close all sessions using the exit stack to avoid context issues + await self.exit_stack.aclose() + except Exception as e: + # Suppress errors related to async context management as they don't affect functionality + if "cancel scope" not in str(e).lower() and "asyncio" not in str(e).lower(): + print(f"Error during cleanup: {e}") + + try: + # Clear sessions + self.sessions.clear() + + # Clear other data structures + self.tool_to_session.clear() + self.available_prompts.clear() + self.static_resources.clear() + self.dynamic_resources.clear() + self.resource_to_session.clear() + self.prompt_to_session.clear() + self.available_tools.clear() + + # Recreate the exit stack for future use + self.exit_stack = AsyncExitStack() + except Exception as e: + print(f"Error during cleanup: {e}") diff --git a/src/client/mcp/frontend.py b/src/client/mcp/frontend.py new file mode 100644 index 00000000..2c645129 --- /dev/null +++ b/src/client/mcp/frontend.py @@ -0,0 +1,60 @@ +import streamlit as st +import os +import requests +import json + +def set_page(): + st.set_page_config( + page_title="MCP Universal Chatbot", + page_icon="🤖", + layout="wide" + ) + +def get_fastapi_base_url(): + return os.getenv("FASTAPI_BASE_URL", "http://127.0.0.1:8000") + + + +def get_server_files(): + files = ["server/mcp/server_config.json"] + try: + with open("server/mcp/server_config.json", "r") as f: config = json.load(f) + for server in config.get("mcpServers", {}).values(): + script_path = server.get("args", [None])[0] + if script_path and os.path.exists(script_path): files.append(script_path) + except FileNotFoundError: st.sidebar.error("server_config.json not found!") + return list(set(files)) + +def display_ide_tab(): + st.header("🔧 Integrated MCP Server IDE") + st.info("Edit your server configuration or scripts. 
Restart the launcher for changes to take effect.") + server_files = get_server_files() + selected_file = st.selectbox("Select a file to edit", options=server_files) + if selected_file: + with open(selected_file, "r") as f: file_content = f.read() + from streamlit_ace import st_ace + new_content = st_ace(value=file_content, language="python" if selected_file.endswith(".py") else "json", theme="monokai", keybinding="vscode", height=500, auto_update=True) + if st.button("Save Changes"): + with open(selected_file, "w") as f: f.write(new_content) + st.success(f"Successfully saved {selected_file}!") + +def display_commands_tab(tools, resources, prompts): + st.header("📖 Discovered MCP Commands") + st.info("These commands were discovered from the MCP backend.") + + if tools: + with st.expander("🛠️ Available Tools (Used automatically by the AI)", expanded=True): + # Extract just the tool names from the tools response + if "tools" in tools and isinstance(tools["tools"], list): + tool_names = [tool.get("name", tool) if isinstance(tool, dict) else tool for tool in tools["tools"]] + st.write(tool_names) + else: + st.json(tools) + + if resources: + with st.expander("📦 Available Resources (Use with `@` or just ``)"): + st.json(resources) + + if prompts: + with st.expander("📝 Available Prompts (Use with `/prompt ` or select in chat)"): + st.json(prompts) diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py index 675248ef..218e313d 100644 --- a/src/client/utils/st_common.py +++ b/src/client/utils/st_common.py @@ -15,6 +15,12 @@ from common import logging_config, help_text from common.schema import PromptPromptType, PromptNameType, SelectAISettings +# Import the MCP initialization function +try: + from launch_server import initialize_mcp_engine_with_model +except ImportError: + initialize_mcp_engine_with_model = None + logger = logging_config.logging.getLogger("client.utils.st_common") @@ -27,9 +33,11 @@ def clear_state_key(state_key: str) -> None: logger.debug("State cleared: %s", state_key) -def state_configs_lookup(state_configs_name: str, key: str) -> dict[str, dict[str, Any]]: +def state_configs_lookup(state_configs_name: str, key: str, section: str = None) -> dict[str, dict[str, Any]]: """Convert state. 
into a lookup based on key""" configs = getattr(state, state_configs_name) + if section: + configs = configs.get(section, []) return {config[key]: config for config in configs if key in config} @@ -164,6 +172,8 @@ def ll_sidebar() -> None: selected_model = state.client_settings["ll_model"]["model"] ll_idx = list(ll_models_enabled.keys()).index(selected_model) if not state.client_settings["selectai"]["enabled"]: + # Store the previous model to detect changes + previous_model = selected_model selected_model = st.sidebar.selectbox( "Chat model:", options=list(ll_models_enabled.keys()), @@ -173,6 +183,18 @@ def ll_sidebar() -> None: disabled=state.client_settings["selectai"]["enabled"], ) + # If the model has changed, reinitialize the MCP engine + if selected_model != previous_model and initialize_mcp_engine_with_model: + try: + # Instead of creating a new event loop, we'll set a flag to indicate + # that the MCP engine needs to be reinitialized + state.mcp_needs_reinit = selected_model + logger.info("MCP engine marked for reinitialization with model: %s", selected_model) + except Exception as ex: + logger.error( + "Failed to mark MCP engine for reinitialization with model %s: %s", selected_model, str(ex) + ) + # Temperature temperature = ll_models_enabled[selected_model]["temperature"] user_temperature = state.client_settings["ll_model"]["temperature"] diff --git a/src/client/utils/st_footer.py b/src/client/utils/st_footer.py index b8d5b643..8314171e 100644 --- a/src/client/utils/st_footer.py +++ b/src/client/utils/st_footer.py @@ -65,25 +65,7 @@ def _inject_footer(selector, insertion_method, footer_html, cleanup_styles=True) """ components.html(js_code, height=0) - -# --- FUNCTION 1: The Cleanup Crew --- -def remove_footer(): - """ - Injects simple JavaScript to find and remove any existing footer. - This MUST be called at the TOP of every page in your app. - """ - js_code = """ - - """ - components.html(js_code, height=0) - - -# --- FUNCTION 2: The Chat Page Footer --- +# --- The Chat Page Footer --- def render_chat_footer(): """ Standardized footer for chat pages. @@ -97,22 +79,3 @@ def render_chat_footer(): _inject_footer( selector='[data-testid="stBottomBlockContainer"]', insertion_method="afterend", footer_html=footer_html ) - - -# --- FUNCTION 3: The Models Page Footer --- -def render_models_footer(): - """ - Standardized footer for models pages. 
- """ - footer_html = f""" - {FOOTER_STYLE} - - """ - _inject_footer( - selector='[data-testid="stAppIframeResizerAnchor"]', - insertion_method="beforebegin", - footer_html=footer_html, - cleanup_styles=False, - ) diff --git a/src/common/schema.py b/src/common/schema.py index 5cc42fcf..164dc258 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -108,7 +108,9 @@ class DatabaseVectorStorage(BaseModel): """Database Vector Storage Tables""" vector_store: Optional[str] = Field( - default=None, description="Vector Store Table Name (auto-generated, do not set)", readOnly=True + default=None, + description="Vector Store Table Name (auto-generated, do not set)", + json_schema_extra={"readOnly": True}, ) alias: Optional[str] = Field(default=None, description="Identifiable Alias") model: Optional[str] = Field(default=None, description="Embedding Model") @@ -121,8 +123,8 @@ class DatabaseVectorStorage(BaseModel): class DatabaseSelectAIObjects(BaseModel): """Database SelectAI Objects""" - owner: Optional[str] = Field(default=None, description="Object Owner", readOnly=True) - name: Optional[str] = Field(default=None, description="Object Name", readOnly=True) + owner: Optional[str] = Field(default=None, description="Object Owner", json_schema_extra={"readOnly": True}) + name: Optional[str] = Field(default=None, description="Object Name", json_schema_extra={"readOnly": True}) enabled: bool = Field(default=False, description="SelectAI Enabled") @@ -144,12 +146,14 @@ class Database(DatabaseAuth): """Database Object""" name: str = Field(default="DEFAULT", description="Name of Database (Alias)") - connected: bool = Field(default=False, description="Connection Established", readOnly=True) + connected: bool = Field(default=False, description="Connection Established", json_schema_extra={"readOnly": True}) vector_stores: Optional[list[DatabaseVectorStorage]] = Field( - default=[], description="Vector Storage (read-only)", readOnly=True + default=[], description="Vector Storage (read-only)", json_schema_extra={"readOnly": True} ) selectai: bool = Field(default=False, description="SelectAI Possible") - selectai_profiles: Optional[list] = Field(default=[], description="SelectAI Profiles (read-only)", readOnly=True) + selectai_profiles: Optional[list] = Field( + default=[], description="SelectAI Profiles (read-only)", json_schema_extra={"readOnly": True} + ) # Do not expose the connection to the endpoint _connection: oracledb.Connection = PrivateAttr(default=None) @@ -163,6 +167,40 @@ def set_connection(self, connection: oracledb.Connection) -> None: self._connection = connection +##################################################### +# MCP +##################################################### +class MCPModelConfig(BaseModel): + """MCP Model Configuration""" + + model_id: str = Field(..., description="Model identifier") + service_type: Literal["ollama", "openai"] = Field(..., description="AI service type") + base_url: str = Field(default="http://localhost:11434", description="Base URL for API") + api_key: Optional[str] = Field(default=None, description="API key", json_schema_extra={"sensitive": True}) + enabled: bool = Field(default=True, description="Model availability status") + streaming: bool = Field(default=False, description="Enable streaming responses") + temperature: float = Field(default=1.0, description="Model temperature") + max_tokens: int = Field(default=2048, description="Maximum tokens per response") + + +class MCPToolConfig(BaseModel): + """MCP Tool Configuration""" + + name: str = 
Field(..., description="Tool name") + description: str = Field(..., description="Tool description") + parameters: dict[str, Any] = Field(..., description="Tool parameters") + enabled: bool = Field(default=True, description="Tool availability status") + + +class MCPSettings(BaseModel): + """MCP Global Settings""" + + models: list[MCPModelConfig] = Field(default_factory=list, description="Available MCP models") + tools: list[MCPToolConfig] = Field(default_factory=list, description="Available MCP tools") + default_model: Optional[str] = Field(default=None, description="Default model identifier") + enabled: bool = Field(default=True, description="Enable or disable MCP functionality") + + ##################################################### # Models ##################################################### @@ -225,6 +263,16 @@ def check_provider(self): raise ValueError(f"Provider '{self.provider}' is not valid. Must be one of: {providers}") return self + def check_provider_matches_type(self): + """Validate valid API""" + providers = get_args(ModelProviders) + if not self.provider or self.provider == "unset": + return self + + if self.provider not in providers: + raise ValueError(f"Provider '{self.provider}' is not valid. Must be one of: {providers}") + return self + ##################################################### # Oracle Cloud Infrastructure @@ -239,7 +287,9 @@ class OracleCloudSettings(BaseModel): """Store Oracle Cloud Infrastructure Settings""" auth_profile: str = Field(default="DEFAULT", description="Config File Profile") - namespace: Optional[str] = Field(default=None, description="Object Store Namespace", readOnly=True) + namespace: Optional[str] = Field( + default=None, description="Object Store Namespace", json_schema_extra={"readOnly": True} + ) user: Optional[str] = Field( default=None, description="Optional if using Auth Token", @@ -380,6 +430,7 @@ class Configuration(BaseModel): model_configs: Optional[list[Model]] = None oci_configs: Optional[list[OracleCloudSettings]] = None prompt_configs: Optional[list[Prompt]] = None + mcp_configs: Optional[list[MCPModelConfig]] = Field(default=None, description="List of MCP configurations") def model_dump_public(self, incl_sensitive: bool = False, incl_readonly: bool = False) -> dict: """Remove marked fields for FastAPI Response""" @@ -502,3 +553,6 @@ class EvaluationReport(Evaluation): TestSetsIdType = TestSets.__annotations__["tid"] TestSetsNameType = TestSets.__annotations__["name"] TestSetDateType = TestSets.__annotations__["created"] +MCPModelIdType = MCPModelConfig.__annotations__["model_id"] +MCPServiceType = MCPModelConfig.__annotations__["service_type"] +MCPToolNameType = MCPToolConfig.__annotations__["name"] diff --git a/src/hello_world.py b/src/hello_world.py new file mode 100644 index 00000000..2028682b --- /dev/null +++ b/src/hello_world.py @@ -0,0 +1,81 @@ +import asyncio +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client + +from langgraph.prebuilt import create_react_agent +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_mcp_adapters.tools import load_mcp_tools + + +client = MultiServerMCPClient( + { + "optimizer": { + "transport": "streamable_http", + "url": "http://localhost:8000/mcp/", + "headers": {"Authorization": "Bearer demo_api_key"}, + } + } +) +async def call_tool(name: str): + tools = await client.get_tools() + agent = create_react_agent("openai:gpt-4o-mini", tools) + math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"}) + 
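+    # ainvoke returns the full LangGraph state; the final text answer is
+    # typically math_response["messages"][-1].content.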
print(math_response) + +# async def call_tool(name: str): +# async with client.session("optimizer") as session: + # tools = await load_mcp_tools(session) + # agent = create_react_agent("openai:gpt-4o-mini", tools) + # # math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"}) + # # weather_response = await agent.ainvoke({"messages": "what is the weather in nyc?"}) + # database_response = await agent.ainvoke({"messages": "connect to OPTIMIZER_DEFAULT"}) + # database_response = await agent.ainvoke({"messages": "show me a list of table names"}) + # print(database_response) + # # print(weather_response) + +asyncio.run(call_tool("Ford")) + +# async def call_tool(name: str): +# async with streamablehttp_client(config) as (read, write, _): +# async with ClientSession(read, write) as session: +# # Initialize the connection +# await session.initialize() + +# # Get tools +# tools = await load_mcp_tools(session) +# agent = create_react_agent("openai:gpt-4o-mini", tools) +# math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"}) +# print(math_response) + + +# asyncio.run(call_tool("Ford")) + +# client = Client(config) +# async def call_tool(name: str): +# async with client: +# print(f"Connected: {client.is_connected()}") +# tools = await client.load_mcp_tools(client) +# # agent = create_react_agent("openai:gpt-4o-mini", tools) +# # math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"}) +# # print(math_response) +# # result = await client.call_tool("optimizer_greet", {"name": name}) +# # print(result) +# # result = await client.call_tool("optimizer_multiply", {"a": 5, "b": 3}) +# # print(result) + + +# from mcp import ClientSession +# from mcp.client.streamable_http import streamablehttp_client + +# from langgraph.prebuilt import create_react_agent +# from langchain_mcp_adapters.tools import load_mcp_tools + +# async with streamablehttp_client("http://localhost:8000/mcp/") as (read, write, _): +# async with ClientSession(read, write) as session: +# # Initialize the connection +# await session.initialize() + +# # Get tools +# tools = await load_mcp_tools(session) +# agent = create_react_agent("openai:gpt-4.1", tools) +# math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"}) diff --git a/src/launch_client.py b/src/launch_client.py index d4051512..60af0451 100644 --- a/src/launch_client.py +++ b/src/launch_client.py @@ -91,7 +91,7 @@ def main() -> None: } .stAppHeader img[alt="Logo"] { width: 50%; - } + } """, ) @@ -131,6 +131,7 @@ def main() -> None: state.disabled["model_cfg"] = os.environ.get("DISABLE_MODEL_CFG", "false").lower() == "true" state.disabled["oci_cfg"] = os.environ.get("DISABLE_OCI_CFG", "false").lower() == "true" state.disabled["settings"] = os.environ.get("DISABLE_SETTINGS", "false").lower() == "true" + state.disabled["mcp_cfg"] = os.environ.get("DISABLE_MCP_CFG", "false").lower() == "true" # Left Hand Side - Navigation chatbot = st.Page("client/content/chatbot.py", title="ChatBot", icon="💬", default=True) diff --git a/src/launch_server.py b/src/launch_server.py index fde83bda..a14d8ca4 100644 --- a/src/launch_server.py +++ b/src/launch_server.py @@ -2,13 +2,17 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
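+Starts the combined FastAPI and FastMCP server; MCP endpoints are mounted under /mcp.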
""" -# spell-checker:ignore fastapi laddr checkpointer langgraph litellm -# spell-checker:ignore noauth apiserver configfile selectai giskard ollama llms -# pylint: disable=redefined-outer-name,wrong-import-position +# spell-checker:ignore configfile fastmcp noauth selectai getpid procs litellm giskard ollama +# spell-checker:ignore dotenv apiserver laddr -import os +# Patch litellm for Giskard/Ollama issue +import server.patches.litellm_patch # pylint: disable=unused-import, wrong-import-order +# Set OS Environment before importing other modules # Set OS Environment (Don't move their position to reflect on imports) +# pylint: disable=wrong-import-position +import os + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" os.environ["GSK_DISABLE_SENTRY"] = "true" os.environ["GSK_DISABLE_ANALYTICS"] = "true" @@ -16,37 +20,49 @@ app_home = os.path.dirname(os.path.abspath(__file__)) if "TNS_ADMIN" not in os.environ: os.environ["TNS_ADMIN"] = os.path.join(app_home, "tns_admin") - -# Patch litellm for Giskard/Ollama issue -import server.patches.litellm_patch # pylint: disable=unused-import +# pylint: enable=wrong-import-position import argparse -import queue +import asyncio +from contextlib import asynccontextmanager +from pathlib import Path import secrets import socket import subprocess -import threading +import sys from typing import Annotated from pathlib import Path import uvicorn -import psutil +from fastapi import FastAPI, APIRouter, Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastmcp import FastMCP, settings +from fastmcp.server.auth import StaticTokenVerifier +from langgraph.checkpoint.memory import InMemorySaver -from fastapi import FastAPI, HTTPException, Depends, status, APIRouter + +# Third Party +import psutil +import uvicorn +from fastapi import FastAPI, APIRouter, Depends, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastmcp import FastMCP, settings +from fastmcp.server.auth import StaticTokenVerifier + +# Configuration +from server.bootstrap import configfile # pylint: disable=ungrouped-imports # Logging from common import logging_config from common._version import __version__ -# Configuration -from server.bootstrap import configfile - logger = logging_config.logging.getLogger("launch_server") +# Establish LangGraph Short-Term Memory (thread-level persistence) +graph_memory = InMemorySaver() ########################################## -# Process Control +# Client Process Control ########################################## def start_server(port: int = 8000, logfile: bool = False) -> int: """Start the uvicorn server for FastAPI""" @@ -71,54 +87,33 @@ def get_pid_using_port(port: int) -> int: continue return None - def start_subprocess(port: int, logfile: bool) -> subprocess.Popen: - """Start the uvicorn server as a subprocess.""" - logger.info("API server starting on port: %i", port) - log_file = open(f"apiserver_{port}.log", "a", encoding="utf-8") if logfile else None - stdout = stderr = log_file if logfile else subprocess.PIPE - process = subprocess.Popen( - [ - "uvicorn", - "launch_server:create_app", - "--factory", - "--host", - "0.0.0.0", - "--port", - str(port), - ], - stdout=stdout, - stderr=stderr, - ) - logger.info("API server started on Port: %i; PID: %i", port, process.pid) - return process - port = port or find_available_port() - existing_pid = get_pid_using_port(port) - if existing_pid: + if existing_pid := get_pid_using_port(port): logger.info("API server 
         return existing_pid
 
-    popen_queue = queue.Queue()
-    thread = threading.Thread(
-        target=lambda: popen_queue.put(start_subprocess(port, logfile)),
-        daemon=True,
-    )
-    thread.start()
+    client_args = [sys.executable, __file__, "--port", str(port)]
+    if logfile:
+        log_file = open(f"apiserver_{port}.log", "a", encoding="utf-8")  # pylint: disable=consider-using-with
+        stdout = stderr = log_file
+    else:
+        stdout = stderr = subprocess.PIPE
 
-    return popen_queue.get().pid
+    process = subprocess.Popen(client_args, stdout=stdout, stderr=stderr)  # pylint: disable=consider-using-with
+    logger.info("Server started on port %i with PID %i", port, process.pid)
+    return process.pid
 
 
 def stop_server(pid: int) -> None:
-    """Stop the uvicorn server for FastAPI."""
+    """Stop the uvicorn server for FastAPI when started via the client"""
     try:
         proc = psutil.Process(pid)
         proc.terminate()
         proc.wait()
+        logger.info("API server stopped.")
     except (psutil.NoSuchProcess, psutil.AccessDenied) as ex:
         logger.error("Failed to terminate process with PID: %i - %s", pid, ex)
-    logger.info("API server stopped.")
-
 
 ##########################################
 # Server App and API Key
 ##########################################
@@ -136,68 +131,124 @@ def get_api_key() -> str:
     return os.getenv("API_SERVER_KEY")
 
 
-def verify_key(
+def fastapi_verify_key(
     http_auth: Annotated[
         HTTPAuthorizationCredentials,
         Depends(HTTPBearer(description="Please provide API_SERVER_KEY.")),
     ],
 ) -> None:
-    """Verify that the provided API key is correct."""
+    """FastAPI: Verify that the provided API key is correct."""
     if http_auth.credentials != get_api_key():
         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
 
 
-def register_endpoints(noauth: APIRouter, auth: APIRouter):
+##########################################
+# Endpoint Registration
+##########################################
+async def register_endpoints(mcp: FastMCP, auth: APIRouter, noauth: APIRouter):
     """Register API Endpoints - Imports to avoid bootstrapping before config file read
     New endpoints need to be registered in server.api.v1.__init__.py
     """
-    import server.api.v1 as api_v1  # pylint: disable=import-outside-toplevel
+    logger.debug("Starting Endpoint Registration")
+    # pylint: disable=import-outside-toplevel
+    import server.api.v1 as api_v1
+    from server.mcp import register_all_mcp
 
     # No-Authentication (probes only)
     noauth.include_router(api_v1.probes.noauth, prefix="/v1", tags=["Probes"])
 
     # Authenticated
     auth.include_router(api_v1.chat.auth, prefix="/v1/chat", tags=["Chatbot"])
-    auth.include_router(api_v1.databases.auth, prefix="/v1/databases", tags=["Config - Databases"])
     auth.include_router(api_v1.embed.auth, prefix="/v1/embed", tags=["Embeddings"])
-    auth.include_router(api_v1.models.auth, prefix="/v1/models", tags=["Config - Models"])
-    auth.include_router(api_v1.oci.auth, prefix="/v1/oci", tags=["Config - Oracle Cloud Infrastructure"])
-    auth.include_router(api_v1.prompts.auth, prefix="/v1/prompts", tags=["Tools - Prompts"])
     auth.include_router(api_v1.selectai.auth, prefix="/v1/selectai", tags=["SelectAI"])
-    auth.include_router(api_v1.settings.auth, prefix="/v1/settings", tags=["Tools - Settings"])
+    auth.include_router(api_v1.prompts.auth, prefix="/v1/prompts", tags=["Tools - Prompts"])
     auth.include_router(api_v1.testbed.auth, prefix="/v1/testbed", tags=["Tools - Testbed"])
+    auth.include_router(api_v1.settings.auth, prefix="/v1/settings", tags=["Config - Settings"])
+    auth.include_router(api_v1.databases.auth, prefix="/v1/databases", tags=["Config - Databases"])
Databases"]) + auth.include_router(api_v1.models.auth, prefix="/v1/models", tags=["Config - Models"]) + auth.include_router(api_v1.oci.auth, prefix="/v1/oci", tags=["Config - Oracle Cloud Infrastructure"]) + auth.include_router(api_v1.mcp.auth, prefix="/v1/mcp", tags=["Config - MCP Servers"]) + + # Auto-discover all MCP tools and register HTTP + MCP endpoints + mcp_router = APIRouter(prefix="/mcp", tags=["MCP Tools"]) + await register_all_mcp(mcp, auth) + auth.include_router(mcp_router) + logger.debug("Finished Endpoint Registration") ############################################################################# # APP FACTORY ############################################################################# -def create_app(config: str = None) -> FastAPI: - """Create and configure the FastAPI app.""" +async def create_app(config: str = "") -> FastAPI: + """FastAPI Application Factory""" + if not config: config = configfile.config_file_path() config_file = Path(os.getenv("CONFIG_FILE", config)) configfile.ConfigStore.load_from_file(config_file) - app = FastAPI( + # FastMCP Server + fastmcp_verifier = StaticTokenVerifier( + tokens={get_api_key(): {"client_id": "optimizer", "scopes": ["read", "write"]}} + ) + settings.stateless_http = True + fastmcp_app = FastMCP( + name="Oracle AI Optimizer and Toolkit MCP Server", + version=__version__, + auth=fastmcp_verifier, + include_fastmcp_meta=False, + ) + fastmcp_engine = fastmcp_app.http_app(path="/") + + @asynccontextmanager + async def combined_lifespan(fastapi_app: FastAPI): + """Ensures all MCP Servers are cleaned up""" + async with fastmcp_engine.lifespan(fastapi_app): + yield + # Shutdown cleanup + logger.info("Cleaning up leftover processes...") + parent = psutil.Process(os.getpid()) + children = parent.children(recursive=True) + for p in children: + try: + p.terminate() + except psutil.NoSuchProcess: + continue + # Wait synchronously, outside the event loop + _, still_alive = psutil.wait_procs(children, timeout=3) + for p in still_alive: + try: + p.kill() + except psutil.NoSuchProcess: + continue + + # FastAPI Server + fastapi_app = FastAPI( title="Oracle AI Optimizer and Toolkit", version=__version__, docs_url="/v1/docs", openapi_url="/v1/openapi.json", + lifespan=combined_lifespan, license_info={ "name": "Universal Permissive License", "url": "http://oss.oracle.com/licenses/upl", }, ) + # Store MCP in the app state + fastapi_app.state.fastmcp_app = fastmcp_app + # Register MCP Server into FastAPI + fastapi_app.mount("/mcp", fastmcp_engine) + # Setup Routes and Register non-MCP endpoints noauth = APIRouter() - auth = APIRouter(dependencies=[Depends(verify_key)]) + auth = APIRouter(dependencies=[Depends(fastapi_verify_key)]) - # Register Endpoints - register_endpoints(noauth, auth) - app.include_router(noauth) - app.include_router(auth) + # Register the endpoints + await register_endpoints(fastmcp_app, auth, noauth) + fastapi_app.include_router(noauth) + fastapi_app.include_router(auth) - return app + return fastapi_app if __name__ == "__main__": @@ -209,10 +260,23 @@ def create_app(config: str = None) -> FastAPI: default=configfile.config_file_path(), help="Full path to configuration file (JSON)", ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to start server", + ) args = parser.parse_args() PORT = int(os.getenv("API_SERVER_PORT", "8000")) logger.info("API Server Using port: %i", PORT) - app = create_app(args.config) - uvicorn.run(app, host="0.0.0.0", port=PORT, log_config=logging_config.LOGGING_CONFIG) + # Sync 
diff --git a/src/pyproject.toml b/src/pyproject.toml
index fce2dc60..87a33b3c 100644
--- a/src/pyproject.toml
+++ b/src/pyproject.toml
@@ -46,7 +46,7 @@ server = [
     "langchain-perplexity==0.1.2",
     "langchain-xai==0.2.5",
     "langgraph==0.6.6",
-    "litellm==1.76.1",
+    "litellm==1.76.1",
     "llama-index==0.13.3",
     "lxml==6.0.0",
     "matplotlib==3.10.6",
diff --git a/src/server/Dockerfile b/src/server/Dockerfile
index 70992359..c9113274 100644
--- a/src/server/Dockerfile
+++ b/src/server/Dockerfile
@@ -12,7 +12,7 @@ ENV RUNUSER=oracleai
 ENV PATH=/opt/.venv/bin:$PATH
 
 RUN microdnf --nodocs -y update && \
-    microdnf --nodocs -y install python3.11 python3.11-pip && \
+    microdnf --nodocs -y install python3.11 python3.11-pip sqlcl && \
     microdnf clean all && \
     python3.11 -m venv --symlinks --upgrade-deps /opt/.venv && \
     groupadd $RUNUSER && \
diff --git a/src/server/agents/tools/selectai.py b/src/server/agents/tools/selectai.py
index fb0f40ac..1b24a943 100644
--- a/src/server/agents/tools/selectai.py
+++ b/src/server/agents/tools/selectai.py
@@ -10,9 +10,10 @@
 from langchain_core.tools import BaseTool, tool
 from langchain_core.runnables import RunnableConfig
 
-from common import logging_config
 from server.api.utils.databases import execute_sql
 
+from common import logging_config
+
 logger = logging_config.logging.getLogger("server.tools.selectai_executor")
 
 # ------------------------------------------------------------------------------
diff --git a/src/server/api/core/bootstrap.py b/src/server/api/core/bootstrap.py
index fd970758..5db865e0 100644
--- a/src/server/api/core/bootstrap.py
+++ b/src/server/api/core/bootstrap.py
@@ -4,7 +4,7 @@
 """
 # spell-checker:ignore genai
 
-from server.bootstrap import databases, models, oci, prompts, settings
+from server.bootstrap import databases, models, oci, prompts, settings, mcp
 from common import logging_config
 
 logger = logging_config.logging.getLogger("api.core.bootstrap")
@@ -14,3 +14,4 @@
 OCI_OBJECTS = oci.main()
 PROMPT_OBJECTS = prompts.main()
 SETTINGS_OBJECTS = settings.main()
+MCP_OBJECTS = mcp.main()
diff --git a/src/server/api/core/mcp.py b/src/server/api/core/mcp.py
new file mode 100644
index 00000000..a0a5d128
--- /dev/null
+++ b/src/server/api/core/mcp.py
@@ -0,0 +1,59 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+""" + +# spell-checker:ignore streamable +import os + +# from langchain_mcp_adapters.client import MultiServerMCPClient +# from typing import Optional, List, Dict, Any +# from common.schema import MCPModelConfig, MCPToolConfig, MCPSettings +# from server.bootstrap import mcp as mcp_bootstrap +from common import logging_config + +logger = logging_config.logging.getLogger("api.core.mcp") + + +def get_client(server: str = "http://127.0.0.1", port: int = 8000, client: str = None) -> dict: + """Get the MCP Client Configuration""" + mcp_client = { + "mcpServers": { + "optimizer": { + "type": "streamableHttp", + "transport": "streamable_http", + "url": f"{server}:{port}/mcp/", + "headers": {"Authorization": f"Bearer {os.getenv('API_SERVER_KEY')}"}, + } + } + } + if client == "langgraph": + del mcp_client["mcpServers"]["optimizer"]["type"] + + +# def get_mcp_model(model_id: str) -> Optional[MCPModelConfig]: +# """Get MCP model configuration by ID""" +# for model in mcp_bootstrap.MCP_MODELS: +# if model.model_id == model_id: +# return model +# return None + + +# def get_mcp_tool(tool_name: str) -> Optional[MCPToolConfig]: +# """Get MCP tool configuration by name""" +# for tool in mcp_bootstrap.MCP_TOOLS: +# if tool.name == tool_name: +# return tool +# return None + + +# def update_mcp_settings(settings: Dict[str, Any]) -> MCPSettings: +# """Update MCP settings""" +# if not mcp_bootstrap.MCP_SETTINGS: +# raise ValueError("MCP settings not initialized") + +# for key, value in settings.items(): +# if hasattr(mcp_bootstrap.MCP_SETTINGS, key): +# setattr(mcp_bootstrap.MCP_SETTINGS, key, value) + +# return mcp_bootstrap.MCP_SETTINGS diff --git a/src/server/api/core/models.py b/src/server/api/core/models.py index a8114224..b86f2a8a 100644 --- a/src/server/api/core/models.py +++ b/src/server/api/core/models.py @@ -44,9 +44,13 @@ def get_model( ) -> Union[list[Model], Model, None]: """Used in direct call from list_models and agents.models""" model_objects = bootstrap.MODEL_OBJECTS - - logger.debug("%i models are defined", len(model_objects)) - + logger.debug( + "Filtering %i models for id: %s; type: %s; disabled: %s", + len(model_objects), + model_id, + model_type, + include_disabled + ) model_filtered = [ model for model in model_objects diff --git a/src/server/api/core/settings.py b/src/server/api/core/settings.py index 0f993721..81de678e 100644 --- a/src/server/api/core/settings.py +++ b/src/server/api/core/settings.py @@ -54,11 +54,15 @@ def get_server_config() -> Configuration: prompt_objects = bootstrap.PROMPT_OBJECTS prompt_configs = list(prompt_objects) + # mcp_objects = bootstrap.MCP_OBJECTS + # mcp_configs = list(mcp_objects) + full_config = { "database_configs": database_configs, "model_configs": model_configs, "oci_configs": oci_configs, "prompt_configs": prompt_configs, + # "mcp_configs": mcp_configs, } return full_config @@ -92,6 +96,9 @@ def update_server_config(config_data: dict) -> None: if "prompt_configs" in config_data: bootstrap.PROMPT_OBJECTS = config.prompt_configs or [] + if "mcp_configs" in config_data: + bootstrap.MCP_OBJECTS = config.mcp_configs or [] + def load_config_from_json_data(config_data: dict, client: ClientIdType = None) -> None: """Shared logic for loading settings from JSON data.""" diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 58b71780..39f1b54d 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -21,8 +21,7 @@ from server.api.core.models import UnknownModelError -from common import schema -from common 
diff --git a/src/server/api/core/models.py b/src/server/api/core/models.py
index a8114224..b86f2a8a 100644
--- a/src/server/api/core/models.py
+++ b/src/server/api/core/models.py
@@ -44,9 +44,13 @@ def get_model(
 ) -> Union[list[Model], Model, None]:
     """Used in direct call from list_models and agents.models"""
     model_objects = bootstrap.MODEL_OBJECTS
-
-    logger.debug("%i models are defined", len(model_objects))
-
+    logger.debug(
+        "Filtering %i models for id: %s; type: %s; disabled: %s",
+        len(model_objects),
+        model_id,
+        model_type,
+        include_disabled,
+    )
     model_filtered = [
         model
         for model in model_objects
diff --git a/src/server/api/core/settings.py b/src/server/api/core/settings.py
index 0f993721..81de678e 100644
--- a/src/server/api/core/settings.py
+++ b/src/server/api/core/settings.py
@@ -54,11 +54,15 @@ def get_server_config() -> Configuration:
     prompt_objects = bootstrap.PROMPT_OBJECTS
     prompt_configs = list(prompt_objects)
 
+    # mcp_objects = bootstrap.MCP_OBJECTS
+    # mcp_configs = list(mcp_objects)
+
     full_config = {
         "database_configs": database_configs,
         "model_configs": model_configs,
         "oci_configs": oci_configs,
         "prompt_configs": prompt_configs,
+        # "mcp_configs": mcp_configs,
     }
 
     return full_config
@@ -92,6 +96,9 @@ def update_server_config(config_data: dict) -> None:
     if "prompt_configs" in config_data:
         bootstrap.PROMPT_OBJECTS = config.prompt_configs or []
 
+    if "mcp_configs" in config_data:
+        bootstrap.MCP_OBJECTS = config.mcp_configs or []
+
 
 def load_config_from_json_data(config_data: dict, client: ClientIdType = None) -> None:
     """Shared logic for loading settings from JSON data."""
diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py
index 58b71780..39f1b54d 100644
--- a/src/server/api/utils/chat.py
+++ b/src/server/api/utils/chat.py
@@ -21,8 +21,7 @@
 from server.api.core.models import UnknownModelError
 
-from common import schema
-from common import logging_config
+from common import schema, logging_config
 
 logger = logging_config.logging.getLogger("api.utils.chat")
diff --git a/src/server/api/utils/embed.py b/src/server/api/utils/embed.py
index 173774b9..ef8abf75 100644
--- a/src/server/api/utils/embed.py
+++ b/src/server/api/utils/embed.py
@@ -27,9 +27,7 @@
 import server.api.utils.databases as utils_databases
 
-from common import schema, functions
-
-from common import logging_config
+from common import schema, functions, logging_config
 
 logger = logging_config.logging.getLogger("api.utils.embed")
diff --git a/src/server/api/utils/mcp.py b/src/server/api/utils/mcp.py
new file mode 100644
index 00000000..e70f72e5
--- /dev/null
+++ b/src/server/api/utils/mcp.py
@@ -0,0 +1,140 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker:ignore astream selectai
+
+import os
+import time
+from typing import Literal, AsyncGenerator
+import json
+import oci
+
+from langchain_core.messages import HumanMessage
+from langchain_core.runnables import RunnableConfig
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from langgraph.graph.state import CompiledStateGraph
+
+import server.api.core.settings as core_settings
+import server.api.core.oci as core_oci
+import server.api.core.prompts as core_prompts
+import server.api.utils.models as util_models
+import server.api.utils.databases as util_databases
+import server.api.utils.selectai as util_selectai
+import server.api.core.mcp as core_mcp
+import server.mcp.graph as graph
+
+from common import logging_config, schema
+
+logger = logging_config.logging.getLogger("api.utils.mcp")
+
+
+def get_client(server: str = "http://127.0.0.1", port: int = 8000) -> dict:
+    """Get the MCP Client Configuration"""
+    mcp_client = {
+        "mcpServers": {
+            "optimizer": {
+                "type": "streamableHttp",
+                "transport": "streamable_http",
+                "url": f"{server}:{port}/mcp/",
+                "headers": {"Authorization": f"Bearer {os.getenv('API_SERVER_KEY')}"},
+            }
+        }
+    }
+
+    return mcp_client
+
+
+def error_response(call: str, message: str, model: dict) -> dict:
+    """Send the error as a response"""
+    response = message
+    if call != "streams":
+        response = {
+            "id": "error",
+            "choices": [{"message": {"role": "assistant", "content": message}, "index": 0, "finish_reason": "stop"}],
+            "created": int(time.time()),
+            "model": model["model"],
+            "object": "chat.completion",
+        }
+    logger.debug("Returning Error Response: %s", response)
+    return response
+
+
+async def completion_generator(
+    client: schema.ClientIdType, request: schema.ChatRequest, call: Literal["completions", "streams"]
+) -> AsyncGenerator[str, None]:
+    """MCP Completion Requests"""
+    client_settings = core_settings.get_client_settings(client)
+    model = request.model_dump()
+    logger.debug("Settings: %s", client_settings)
+    logger.debug("Request: %s", model)
+
+    # Establish LL Model Params (if the request specs a model, otherwise override from settings)
+    if not model["model"]:
+        model = client_settings.ll_model.model_dump()
+
+    # Get OCI Settings
+    oci_config = core_oci.get_oci(client=client)
+
+    # Setup Language Model
+    ll_model = util_models.get_client(model, oci_config)
+    if not ll_model:
+        yield error_response(call, "I'm unable to initialise the Language Model. Please refresh the application.", model)
+        return
+
+    # Setup MCP and bind tools
+    mcp_client = MultiServerMCPClient(
+        {"optimizer": core_mcp.get_client(client="langgraph")["mcpServers"]["optimizer"]}
+    )
+    tools = await mcp_client.get_tools()
+    try:
+        ll_model_with_tools = ll_model.bind_tools(tools)
+    except NotImplementedError as ex:
+        yield error_response(call, str(ex), model)
+        raise
+
+    # Build our Graph
+    agent: CompiledStateGraph = graph.main(tools)
+
+    kwargs = {
+        "input": {"messages": [HumanMessage(content=request.messages[0].content)]},
+        "config": RunnableConfig(
+            configurable={"thread_id": client, "ll_model": ll_model_with_tools, "tools": tools},
+            metadata={"use_history": client_settings.ll_model.chat_history},
+        ),
+    }
+
+    try:
+        async for chunk in agent.astream_events(**kwargs, version="v2"):
+            # The below will produce A LOT of output; uncomment when desperate
+            # logger.debug("Streamed Chunk: %s", chunk)
+            if chunk["event"] == "on_chat_model_stream":
+                if "tools_condition" in str(chunk["metadata"]["langgraph_triggers"]):
+                    continue  # Skip Tool Call messages
+                if "vs_retrieve" in str(chunk["metadata"]["langgraph_node"]):
+                    continue  # Skip Fake-Tool Call messages
+                content = chunk["data"]["chunk"].content
+                if content != "" and call == "streams":
+                    yield content.encode("utf-8")
+            last_response = chunk["data"]
+    except oci.exceptions.ServiceError as ex:
+        error_details = json.loads(ex.message).get("message", "")
+        yield error_response(call, error_details, model)
+        raise
+
+    # Clean Up
+    if call == "streams":
+        yield "[stream_finished]"  # This will break the Chatbot loop
+    elif call == "completions":
+        final_response = last_response["output"]["final_response"]
+        yield final_response  # This will be captured for ChatResponse
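As a rough usage sketch (the function name and event-loop wiring are assumptions), this is how the chat endpoints below consume completion_generator for the streaming case, relying on the [stream_finished] sentinel yielded above:

    from server.api.utils.mcp import completion_generator

    async def stream_demo(client_id, request):
        # Chunks are UTF-8 encoded model text; the sentinel ends the stream.
        async for chunk in completion_generator(client_id, request, "streams"):
            if chunk == "[stream_finished]":
                break
            print(chunk.decode("utf-8"), end="")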
diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py
index ae2dd93d..36bfc586 100644
--- a/src/server/api/utils/models.py
+++ b/src/server/api/utils/models.py
@@ -16,8 +16,7 @@
 import server.api.core.models as core_models
 
 from common.functions import is_url_accessible
-from common import schema
-from common import logging_config
+from common import logging_config, schema
 
 logger = logging_config.logging.getLogger("api.utils.models")
 
@@ -159,19 +158,18 @@ def get_client_embed(model_config: dict, oci_config: schema.OracleCloudSettings)
     else:
         if provider == "hosted_vllm":
             kwargs = {
-                "provider": "openai",
-                "model": full_model_config["id"],
-                "base_url": full_model_config.get("api_base"),
-                "check_embedding_ctx_length":False #To avoid Tiktoken pre-transform on not OpenAI provided server
+                "provider": "openai",
+                "model": full_model_config["id"],
+                "base_url": full_model_config.get("api_base"),
+                "check_embedding_ctx_length": False,  # Avoid Tiktoken pre-transform on non-OpenAI servers
             }
         else:
             kwargs = {
-                "provider": provider,
-                "model": full_model_config["id"],
-                "base_url": full_model_config.get("api_base"),
+                "provider": provider,
+                "model": full_model_config["id"],
+                "base_url": full_model_config.get("api_base"),
             }
-
         if full_model_config.get("api_key"):  # only add if set
             kwargs["api_key"] = full_model_config["api_key"]
         client = init_embeddings(**kwargs)
diff --git a/src/server/api/utils/testbed.py b/src/server/api/utils/testbed.py
index 84528ecb..f12782ec 100644
--- a/src/server/api/utils/testbed.py
+++ b/src/server/api/utils/testbed.py
@@ -20,8 +20,8 @@
 import server.api.utils.databases as utils_databases
 import server.api.utils.models as utils_models
 
-from common import schema
-from common import logging_config
+
+from common import schema, logging_config
 
 logger = logging_config.logging.getLogger("api.utils.testbed")
diff --git a/src/server/api/v1/__init__.py b/src/server/api/v1/__init__.py
index fcd6743f..f9da75e0 100644
--- a/src/server/api/v1/__init__.py
+++ b/src/server/api/v1/__init__.py
@@ -2,5 +2,6 @@
 Copyright (c) 2024, 2025, Oracle and/or its affiliates.
 Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
 """
+# spell-checker:ignore selectai
 
-from . import chat, databases, embed, models, oci, probes, prompts, testbed, settings, selectai
+from . import chat, databases, embed, models, oci, probes, prompts, testbed, settings, selectai, mcp
diff --git a/src/server/api/v1/chat.py b/src/server/api/v1/chat.py
index 0d4379d2..b9b8bbee 100644
--- a/src/server/api/v1/chat.py
+++ b/src/server/api/v1/chat.py
@@ -18,11 +18,10 @@
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
 
-from server.api.utils import chat
-from server.agents import chatbot
+import server.api.utils.mcp as utils_mcp
+import server.mcp.graph as graph
 
-from common import schema
-from common import logging_config
+from common import schema, logging_config
 
 logger = logging_config.logging.getLogger("endpoints.v1.chat")
 
@@ -39,7 +38,7 @@ async def chat_post(
 ) -> ModelResponse:
     """Full Completion Requests"""
     last_message = None
-    async for chunk in chat.completion_generator(client, request, "completions"):
+    async for chunk in utils_mcp.completion_generator(client, request, "completions"):
         last_message = chunk
     return last_message
 
@@ -55,7 +54,7 @@ async def chat_stream(
 ) -> StreamingResponse:
     """Completion Requests"""
     return StreamingResponse(
-        chat.completion_generator(client, request, "streams"),
+        utils_mcp.completion_generator(client, request, "streams"),
         media_type="application/octet-stream",
     )
 
@@ -67,7 +66,8 @@ async def chat_stream(
 )
 async def chat_history_clean(client: schema.ClientIdType = Header(default="server")) -> list[ChatMessage]:
     """Delete all Chat History"""
-    agent: CompiledStateGraph = chatbot.chatbot_graph
+    agent: CompiledStateGraph = graph.main(list())
+    # agent: CompiledStateGraph = chatbot.chatbot_graph
     try:
         _ = agent.update_state(
             config=RunnableConfig(
@@ -89,7 +89,8 @@ async def chat_history_return(client: schema.ClientIdType = Header(default="server")) -> list[ChatMessage]:
     """Return Chat History"""
-    agent: CompiledStateGraph = chatbot.chatbot_graph
+    agent: CompiledStateGraph = graph.main(list())
+    # agent: CompiledStateGraph = chatbot.chatbot_graph
     try:
         state_snapshot = agent.get_state(
             config=RunnableConfig(
diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py
index 221c2ccf..fd88cae4 100644
--- a/src/server/api/v1/databases.py
+++ b/src/server/api/v1/databases.py
@@ -8,8 +8,7 @@
 import server.api.utils.databases as utils_databases
 
-from common import schema
-from common import logging_config
+from common import schema, logging_config
 
 logger = logging_config.logging.getLogger("endpoints.v1.databases")
diff --git a/src/server/api/v1/mcp.py b/src/server/api/v1/mcp.py
new file mode 100644
index 00000000..68e40b9a
--- /dev/null
+++ b/src/server/api/v1/mcp.py
@@ -0,0 +1,131 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+This file is being used in APIs, and not the backend.py file.
+"""
+
+# spell-checker:ignore noauth fastmcp healthz
+from fastapi import APIRouter, Request, Depends
+from fastmcp import FastMCP, Client
+
+import server.api.utils.mcp as utils_mcp
+import server.api.core.mcp as core_mcp
+
+from common import logging_config
+
+logger = logging_config.logging.getLogger("api.v1.mcp")
+
+auth = APIRouter()
+
+
+def get_mcp(request: Request) -> FastMCP:
+    """Get the MCP engine from the app state"""
+    return request.app.state.fastmcp_app
+
+
+@auth.get(
+    "/client",
+    description="Get MCP Client Configuration",
+    response_model=dict,
+)
+async def get_client(server: str = "http://127.0.0.1", port: int = 8000) -> dict:
+    """Get MCP Client Configuration"""
+    return utils_mcp.get_client(server, port)
+
+
+@auth.get(
+    "/tools",
+    description="List available MCP tools",
+    response_model=list[dict],
+)
+async def get_tools(mcp_engine: FastMCP = Depends(get_mcp)) -> list[dict]:
+    """List MCP tools"""
+    tools_info = []
+    client = Client(mcp_engine)
+    try:
+        async with client:
+            tools = await client.list_tools()
+            logger.debug("MCP Tools: %s", tools)
+            for tool_object in tools:
+                tools_info.append(tool_object.model_dump())
+    finally:
+        await client.close()
+
+    return tools_info
+
+
+@auth.get(
+    "/resources",
+    description="List MCP resources",
+    response_model=list[dict],
+)
+async def mcp_list_resources(mcp_engine: FastMCP = Depends(get_mcp)) -> list[dict]:
+    """List MCP Resources"""
+    resources_info = []
+    client = Client(mcp_engine)
+    try:
+        async with client:
+            resources = await client.list_resources()
+            logger.debug("MCP Resources: %s", resources)
+            for resources_object in resources:
+                resources_info.append(resources_object.model_dump())
+    finally:
+        await client.close()
+
+    return resources_info
+
+
+@auth.get(
+    "/prompts",
+    description="List MCP prompts",
+    response_model=list[dict],
+)
+async def mcp_list_prompts(mcp_engine: FastMCP = Depends(get_mcp)) -> list[dict]:
+    """List MCP Prompts"""
+    prompts_info = []
+    client = Client(mcp_engine)
+    try:
+        async with client:
+            prompts = await client.list_prompts()
+            logger.debug("MCP Prompts: %s", prompts)
+            for prompts_object in prompts:
+                prompts_info.append(prompts_object.model_dump())
+    finally:
+        await client.close()
+
+    return prompts_info
+
+
+# @auth.post("/execute", description="Execute an MCP tool", response_model=dict)
+# async def mcp_execute_tool(request: McpToolCallRequest):
+#     """Execute MCP Tool"""
+#     mcp_engine = mcp_engine_obj()
+#     if not mcp_engine:
+#         raise HTTPException(status_code=503, detail="MCP Engine not initialized.")
+#     try:
+#         result = await mcp_engine.execute_mcp_tool(request.tool_name, request.tool_args)
+#         return {"result": result}
+#     except Exception as ex:
+#         logger.error("Error executing MCP tool: %s", ex)
+#         raise HTTPException(status_code=500, detail=str(ex)) from ex
+
+
+# @auth.post("/chat", description="Chat with MCP engine", response_model=dict)
+# async def chat_endpoint(request: ChatRequest):
+#     """Chat with MCP Engine"""
+#     mcp_engine = mcp_engine_obj()
+#     if not mcp_engine:
+#         raise HTTPException(status_code=503, detail="MCP Engine not initialized.")
+#     try:
+#         message_history = request.message_history or [{"role": "user", "content": request.query}]
+#         response_text, _ = await mcp_engine.invoke(message_history=message_history)
+#         return {"response": response_text}
+#     except Exception as ex:
+#         logger.error("Error in MCP chat: %s", ex)
+#         raise HTTPException(status_code=500, detail=str(ex)) from ex
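A hedged example of exercising the new endpoints over HTTP, assuming a server on localhost:8000, the bearer key exported as API_SERVER_KEY, and the requests library available (an assumption, it is not in this diff):

    import os
    import requests

    headers = {"Authorization": f"Bearer {os.environ['API_SERVER_KEY']}"}
    # Mounted under /v1/mcp by register_endpoints() in launch_server.py
    resp = requests.get("http://127.0.0.1:8000/v1/mcp/tools", headers=headers, timeout=30)
    print(resp.json())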
diff --git a/src/server/api/v1/oci.py b/src/server/api/v1/oci.py
index 56a7f240..3151877e 100644
--- a/src/server/api/v1/oci.py
+++ b/src/server/api/v1/oci.py
@@ -12,8 +12,7 @@
 import server.api.utils.oci as utils_oci
 import server.api.utils.models as utils_models
 
-from common import schema
-from common import logging_config
+from common import schema, logging_config
 
 logger = logging_config.logging.getLogger("endpoints.v1.oci")
diff --git a/src/server/api/v1/probes.py b/src/server/api/v1/probes.py
index 6dba7c3d..34a986c6 100644
--- a/src/server/api/v1/probes.py
+++ b/src/server/api/v1/probes.py
@@ -2,13 +2,20 @@
 Copyright (c) 2024, 2025, Oracle and/or its affiliates.
 Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
 """
-# spell-checker:ignore noauth
-from fastapi import APIRouter
+# spell-checker:ignore noauth fastmcp healthz
+from datetime import datetime
+from fastapi import APIRouter, Request, Depends
+from fastmcp import FastMCP
 
 noauth = APIRouter()
 
 
+def get_mcp(request: Request) -> FastMCP:
+    """Get the MCP engine from the app state"""
+    return request.app.state.fastmcp_app
+
+
 @noauth.get("/liveness")
 async def liveness_probe():
     """Kubernetes liveness probe"""
@@ -19,3 +26,19 @@ async def liveness_probe():
 async def readiness_probe():
     """Kubernetes readiness probe"""
     return {"status": "ready"}
+
+
+@noauth.get("/mcp/healthz")
+def mcp_healthz(mcp_engine: FastMCP = Depends(get_mcp)):
+    """Check if MCP server is ready."""
+    if mcp_engine is None:
+        return {"status": "not ready"}
+
+    server = mcp_engine.__dict__["_mcp_server"].__dict__
+    return {
+        "status": "ready",
+        "name": server["name"],
+        "version": server["version"],
+        "available_tools": len(getattr(mcp_engine, "available_tools", [])) if mcp_engine else 0,
+        "timestamp": datetime.now().isoformat(),
+    }
diff --git a/src/server/api/v1/prompts.py b/src/server/api/v1/prompts.py
index 64713b61..c6362fb7 100644
--- a/src/server/api/v1/prompts.py
+++ b/src/server/api/v1/prompts.py
@@ -9,8 +9,7 @@
 import server.api.core.prompts as core_prompts
 
-from common import schema
-from common import logging_config
+from common import schema, logging_config
 
 logger = logging_config.logging.getLogger("endpoints.v1.prompts")
diff --git a/src/server/api/v1/settings.py b/src/server/api/v1/settings.py
index be810137..512303f4 100644
--- a/src/server/api/v1/settings.py
+++ b/src/server/api/v1/settings.py
@@ -11,8 +11,7 @@
 import server.api.core.settings as core_settings
 
-from common import schema
-from common import logging_config
+from common import logging_config, schema
 
 logger = logging_config.logging.getLogger("endpoints.v1.settings")
 
@@ -37,7 +36,7 @@ async def settings_get(
     full_config: bool = False,
     incl_sensitive: bool = Depends(_incl_sensitive_param),
     incl_readonly: bool = Depends(_incl_readonly_param),
-) -> Union[schema.Configuration, schema.Settings]:
+) -> Union[schema.Configuration, schema.Settings, JSONResponse]:
     """Get settings for a specific client by name"""
     try:
         client_settings = core_settings.get_client_settings(client)
@@ -54,8 +53,13 @@
             model_configs=config.get("model_configs"),
             oci_configs=config.get("oci_configs"),
             prompt_configs=config.get("prompt_configs"),
+            mcp_configs=config.get("mcp_configs", None),
         )
-        return JSONResponse(content=response.model_dump_public(incl_sensitive=incl_sensitive, incl_readonly=incl_readonly))
+        if incl_sensitive or incl_readonly:
+            return JSONResponse(
+                content=response.model_dump_public(incl_sensitive=incl_sensitive, incl_readonly=incl_readonly)
+            )
+        return response
 
 
 @auth.patch(
@@ -113,12 +117,12 @@ async def load_settings_from_file(
         pass
 
     try:
-        if not file.filename.endswith(".json"):
+        if not file.filename or not file.filename.endswith(".json"):
             raise HTTPException(status_code=400, detail="Settings: Only JSON files are supported.")
         contents = await file.read()
         config_data = json.loads(contents)
         core_settings.load_config_from_json_data(config_data, client)
-        return {"message": "Configuration loaded successfully."}
+        return JSONResponse(content={"message": "Configuration loaded successfully."})
     except json.JSONDecodeError as ex:
         raise HTTPException(status_code=400, detail="Settings: Invalid JSON file.") from ex
     except KeyError as ex:
@@ -147,7 +151,7 @@ async def load_settings_from_json(
 
     try:
         core_settings.load_config_from_json_data(payload.model_dump(), client)
-        return {"message": "Configuration loaded successfully."}
+        return JSONResponse(content={"message": "Configuration loaded successfully."})
     except json.JSONDecodeError as ex:
         raise HTTPException(status_code=400, detail="Settings: Invalid JSON file.") from ex
     except KeyError as ex:
diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py
index 0f38fbce..9d28bfbf 100644
--- a/src/server/api/v1/testbed.py
+++ b/src/server/api/v1/testbed.py
@@ -27,8 +27,7 @@
 from server.api.v1 import chat
 
-from common import schema
-from common import logging_config
+from common import logging_config, schema
 
 logger = logging_config.logging.getLogger("endpoints.v1.testbed")
 
@@ -90,7 +89,9 @@ async def testbed_testset_qa(
     client: schema.ClientIdType = Header(default="server"),
 ) -> schema.TestSetQA:
     """Get TestSet Q&A"""
-    return utils_testbed.get_testset_qa(db_conn=utils_databases.get_client_database(client).connection, tid=tid.upper())
+    return utils_testbed.get_testset_qa(
+        db_conn=utils_databases.get_client_database(client).connection, tid=tid.upper()
+    )
 
 
 @auth.delete(
diff --git a/src/server/bootstrap/mcp.py b/src/server/bootstrap/mcp.py
new file mode 100644
index 00000000..c958e102
--- /dev/null
+++ b/src/server/bootstrap/mcp.py
@@ -0,0 +1,89 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+""" + +from typing import List, Optional +import os + +from server.bootstrap.configfile import ConfigStore +from common.schema import MCPSettings, MCPModelConfig, MCPToolConfig +from common import logging_config + +logger = logging_config.logging.getLogger("bootstrap.mcp") + +# Global configuration holders +MCP_SETTINGS: Optional[MCPSettings] = None +MCP_MODELS: List[MCPModelConfig] = [] +MCP_TOOLS: List[MCPToolConfig] = [] + + +def load_mcp_settings(config: dict) -> None: + """Load MCP configuration from config file""" + global MCP_SETTINGS, MCP_MODELS, MCP_TOOLS + + # Convert to settings object first + mcp_settings = MCPSettings( + models=[MCPModelConfig(**model) for model in config.get("models", [])], + tools=[MCPToolConfig(**tool) for tool in config.get("tools", [])], + default_model=config.get("default_model"), + enabled=config.get("enabled", True), + ) + + # Set globals + MCP_SETTINGS = mcp_settings + MCP_MODELS = mcp_settings.models + MCP_TOOLS = mcp_settings.tools + + logger.info("Loaded %i MCP Models and %i Tools", len(MCP_MODELS), len(MCP_TOOLS)) + + +def main() -> MCPSettings: + """Bootstrap MCP Configuration""" + logger.debug("*** Bootstrapping MCP - Start") + + # Load from ConfigStore if available + configuration = ConfigStore.get() + if configuration and configuration.mcp_configs: + logger.debug("Using MCP configs from ConfigStore") + # Convert list of MCPModelConfig objects to MCPSettings + mcp_settings = MCPSettings( + models=configuration.mcp_configs, + tools=[], # No tools in the current schema + default_model=configuration.mcp_configs[0].model_id if configuration.mcp_configs else None, + enabled=True, + ) + else: + # Default MCP configuration + mcp_settings = MCPSettings( + models=[ + MCPModelConfig( + model_id="llama3.1", + service_type="ollama", + base_url=os.environ.get("ON_PREM_OLLAMA_URL", "http://localhost:11434"), + enabled=True, + streaming=False, + temperature=1.0, + max_tokens=2048, + ) + ], + tools=[ + MCPToolConfig( + name="file_reader", + description="Read contents of files", + parameters={"path": "string", "encoding": "string"}, + enabled=True, + ) + ], + default_model=None, + enabled=True, + ) + + logger.info("Loaded %i MCP Models and %i Tools", len(mcp_settings.models), len(mcp_settings.tools)) + logger.debug("*** Bootstrapping MCP - End") + logger.info("MCP Settings: %s", mcp_settings.model_dump_json()) + return mcp_settings + + +if __name__ == "__main__": + main() diff --git a/src/server/bootstrap/oci.py b/src/server/bootstrap/oci.py index c40c297f..dd2d7470 100644 --- a/src/server/bootstrap/oci.py +++ b/src/server/bootstrap/oci.py @@ -10,9 +10,9 @@ from server.bootstrap.configfile import ConfigStore -from common import logging_config from common.schema import OracleCloudSettings +from common import logging_config logger = logging_config.logging.getLogger("bootstrap.oci") diff --git a/src/server/mcp/__init__.py b/src/server/mcp/__init__.py new file mode 100644 index 00000000..8cb3ea53 --- /dev/null +++ b/src/server/mcp/__init__.py @@ -0,0 +1,73 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+""" +# spell-checker:ignore fastapi fastmcp + +import importlib +import pkgutil + +from fastapi import APIRouter +from fastmcp import FastMCP + +from common import logging_config + +logger = logging_config.logging.getLogger("mcp.__init__.py") + + +async def _discover_and_register( + package: str, + mcp: FastMCP = None, + auth: APIRouter = None, +): + """Import all modules in a package and call their register function.""" + try: + pkg = importlib.import_module(package) + except ImportError: + logger.warning("Package %s not found, skipping.", package) + return + + for module_info in pkgutil.walk_packages(pkg.__path__, prefix=f"{package}."): + if module_info.name.endswith("__init__"): + continue + + try: + module = importlib.import_module(module_info.name) + except Exception as ex: + logger.error("Failed to import %s: %s", module_info.name, ex) + continue + + # Decide what to register based on available functions + if hasattr(module, "register"): + logger.info("Registering via %s.register()", module_info.name) + if ".tools." in module.__name__: + await module.register(mcp, auth) + if ".proxies." in module.__name__: + await module.register(mcp) + if ".prompts." in module.__name__: + await module.register(mcp) + # elif hasattr(module, "register_tool"): + # logger.info("Registering tool via %s.register_tool()", module_info.name) + # module.register_tool(mcp, auth) + # elif hasattr(module, "register_prompt"): + # logger.info("Registering prompt via %s.register_prompt()", module_info.name) + # module.register_prompt(mcp) + # elif hasattr(module, "register_resource"): + # logger.info("Registering resource via %s.register_resource()", module_info.name) + # module.register_resource(mcp) + # elif hasattr(module, "register_proxy"): + # logger.info("Registering proxy via %s.register_resource()", module_info.name) + # module.register_resource(mcp) + else: + logger.debug("No register function in %s, skipping.", module_info.name) + + +async def register_all_mcp(mcp: FastMCP, auth: APIRouter): + """ + Auto-discover and register all MCP tools, prompts, resources, and proxies. + """ + logger.info("Starting Registering MCP Components") + await _discover_and_register("server.mcp.tools", mcp=mcp, auth=auth) + await _discover_and_register("server.mcp.proxies", mcp=mcp) + # await _discover_and_register("server.mcp.prompts", mcp=mcp) + logger.info("Finished Registering MCP Components") diff --git a/src/server/mcp/graph.py b/src/server/mcp/graph.py new file mode 100644 index 00000000..00664032 --- /dev/null +++ b/src/server/mcp/graph.py @@ -0,0 +1,171 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+""" +# spell-checker:ignore ainvoke checkpointer + +from datetime import datetime, timezone + +from langchain_core.messages import SystemMessage, ToolMessage +from langchain_core.runnables import RunnableConfig + +from langgraph.graph import StateGraph, MessagesState, START, END +from langgraph.prebuilt import ToolNode, tools_condition + +from common.schema import ChatResponse, ChatUsage, ChatChoices, ChatMessage +from launch_server import graph_memory + +import common.logging_config as logging_config + +logger = logging_config.logging.getLogger("mcp.graph") + + +############################################################################# +# AGENT STATE +############################################################################# +class OptimizerState(MessagesState): + """Establish our Agent State Machine""" + + final_response: ChatResponse # OpenAI Response + cleaned_messages: list # Messages w/o VS Results + + +############################################################################# +# NODES and EDGES +############################################################################# +def respond(state: OptimizerState, config: RunnableConfig) -> ChatResponse: + """Respond in OpenAI Compatible return""" + ai_message = state["messages"][-1] + logger.debug("Formatting to OpenAI compatible response: %s", repr(ai_message)) + if "model_name" in ai_message.response_metadata: + model_id = ai_message.response_metadata["model_name"] + ai_metadata = ai_message + else: + logger.debug("Using Metadata from: %s", repr(ai_metadata)) + model_id = config["metadata"]["ll_model"] + ai_metadata = state["messages"][1] + + finish_reason = ai_metadata.response_metadata.get("finish_reason", "stop") + if finish_reason == "COMPLETE": + finish_reason = "stop" + elif finish_reason == "MAX_TOKENS": + finish_reason = "length" + + openai_response = ChatResponse( + id=ai_message.id, + created=int(datetime.now(timezone.utc).timestamp()), + model=model_id, + usage=ChatUsage( + prompt_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("prompt_tokens", -1), + completion_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("completion_tokens", -1), + total_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("total_tokens", -1), + ), + choices=[ + ChatChoices( + index=0, + message=ChatMessage( + role="ai", + content=ai_message.content, + additional_kwargs=ai_metadata.additional_kwargs, + response_metadata=ai_metadata.response_metadata, + ), + finish_reason=finish_reason, + logprobs=None, + ) + ], + ) + return {"final_response": openai_response} + + +async def client(state: OptimizerState, config: RunnableConfig) -> OptimizerState: + """Get messages from state based on Thread ID""" + logger.debug("Initializing OptimizerState") + messages = get_messages(state, config) + + return {"cleaned_messages": messages} + + +############################################################################# +def get_messages(state: OptimizerState, config: RunnableConfig) -> list: + """Return a list of messages that will be passed to the model for completion + Leave the state as is for GUI functionality""" + use_history = config["metadata"]["use_history"] + + # If user decided for no history, only take the last message + state_messages = state["messages"] if use_history else state["messages"][-1:] + + messages = [] + for msg in state_messages: + if isinstance(msg, SystemMessage): + continue + if isinstance(msg, ToolMessage): + if messages: # Check if there are any messages in the list + messages.pop() # 
+            continue
+        messages.append(msg)
+
+    # # insert the system prompt; remaining messages cleaned
+    # if config["metadata"]["sys_prompt"].prompt:
+    #     messages.insert(0, SystemMessage(content=config["metadata"]["sys_prompt"].prompt))
+
+    return messages
+
+
+def should_continue(state: OptimizerState):
+    """Determine if graph should continue to tools or respond"""
+    messages = state["messages"]
+    last_message = messages[-1]
+    if last_message.tool_calls:
+        return "tools"
+    return "respond"
+
+
+# Define call_model function
+async def call_model(state: OptimizerState, config: RunnableConfig):
+    """Invoke the model"""
+    try:
+        model = config["configurable"].get("ll_model", None)
+        messages = state["messages"]
+        response = await model.ainvoke(messages)
+        return {"messages": [response]}
+    except AttributeError as ex:
+        # The model is not in our RunnableConfig
+        return {"messages": [f"I'm sorry; {ex}"]}
+
+
+# #############################################################################
+# # GRAPH
+# #############################################################################
+def main(tools: list):
+    """Define the graph with MCP tool nodes"""
+    # Build the graph
+    workflow = StateGraph(OptimizerState)
+
+    # Define the nodes
+    workflow.add_node("client", client)
+    workflow.add_node("call_model", call_model)
+    workflow.add_node("tools", ToolNode(tools))
+    workflow.add_node("respond", respond)
+
+    # Add Edges
+    workflow.add_edge(START, "client")
+    workflow.add_edge("client", "call_model")
+    workflow.add_conditional_edges(
+        "call_model",
+        should_continue,
+    )
+    workflow.add_edge("tools", "call_model")
+    workflow.add_edge("respond", END)
+
+    # Compile the graph and return it
+    mcp_graph = workflow.compile(checkpointer=graph_memory)
+    logger.debug("Chatbot Graph Built with tools: %s", tools)
+    ## This will output the Graph in ascii; don't deliver uncommented
+    # mcp_graph.get_graph(xray=True).print_ascii()
+
+    return mcp_graph
+
+
+if __name__ == "__main__":
+    main(list())
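A minimal sketch of driving the compiled graph outside the API, mirroring the kwargs built in server/api/utils/mcp.py (the model/tool wiring and the thread id are assumptions):

    from langchain_core.messages import HumanMessage
    from langchain_core.runnables import RunnableConfig
    import server.mcp.graph as graph

    async def run_once(ll_model_with_tools, tools, question: str):
        agent = graph.main(tools)
        result = await agent.ainvoke(
            {"messages": [HumanMessage(content=question)]},
            config=RunnableConfig(
                configurable={"thread_id": "demo", "ll_model": ll_model_with_tools, "tools": tools},
                metadata={"use_history": True},
            ),
        )
        return result["final_response"]  # ChatResponse built by respond()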
diff --git a/src/server/mcp/prompts/__init__.py b/src/server/mcp/prompts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/server/mcp/prompts/optimizer.py b/src/server/mcp/prompts/optimizer.py
new file mode 100644
index 00000000..52e03226
--- /dev/null
+++ b/src/server/mcp/prompts/optimizer.py
@@ -0,0 +1,21 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+
+# pylint: disable=unused-argument
+# spell-checker:ignore fastmcp
+from fastmcp.prompts.prompt import PromptMessage, TextContent
+
+
+# Basic prompt returning a string (converted to user message automatically)
+async def register(mcp):
+    """Register Out-of-Box Prompts"""
+    optimizer_tags = {"source", "optimizer"}
+
+    @mcp.prompt(name="basic-example-chatbot", tags=optimizer_tags)
+    def basic_example() -> PromptMessage:
+        """Basic system prompt for chatbot."""
+        content = "You are a friendly, helpful assistant."
+        return PromptMessage(role="system", content=TextContent(type="text", text=content))
diff --git a/src/server/mcp/proxies/__init__.py b/src/server/mcp/proxies/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/server/mcp/proxies/sqlcl.py b/src/server/mcp/proxies/sqlcl.py
new file mode 100644
index 00000000..a9f3e445
--- /dev/null
+++ b/src/server/mcp/proxies/sqlcl.py
@@ -0,0 +1,72 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker:ignore sqlcl fastmcp connmgr noupdates savepwd
+
+import os
+import shutil
+import subprocess
+
+import server.api.utils.databases as utils_databases
+
+from common import logging_config
+
+logger = logging_config.logging.getLogger("mcp.proxies.sqlcl")
+
+
+async def register(mcp):
+    """Register the SQLcl MCP Server as Local (via Proxy)"""
+    tool_name = "SQLclProxy"
+
+    sqlcl_binary = shutil.which("sql")
+    if sqlcl_binary:
+        env_vars = os.environ.copy()
+        env_vars["TNS_ADMIN"] = os.getenv("TNS_ADMIN", "tns_admin")
+        config = {
+            "mcpServers": {
+                tool_name: {
+                    "name": tool_name,
+                    "command": sqlcl_binary,
+                    "args": ["-mcp", "-daemon", "-thin", "-noupdates"],
+                    "env": env_vars,
+                }
+            }
+        }
+        databases = utils_databases.get_databases(validate=False)
+        for database in databases:
+            # Start sql in no-login mode
+            try:
+                proc = subprocess.Popen(
+                    [sqlcl_binary, "/nolog"],
+                    stdin=subprocess.PIPE,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                    text=True,
+                    env=env_vars,
+                )
+
+                # Prepare commands: connect, then exit
+                commands = [
+                    f"connmgr delete -conn OPTIMIZER_{database.name}",
+                    (
+                        f"conn -savepwd -save OPTIMIZER_{database.name} "
+                        f"-user {database.user} -password {database.password} "
+                        f"-url {database.dsn}"
+                    ),
+                    "exit",
+                ]
+
+                # Send commands joined by newlines
+                proc.communicate("\n".join(commands) + "\n")
+                logger.info("Established Connection Store for: %s", database.name)
+            except subprocess.SubprocessError as ex:
+                logger.error("Failed to create connection store: %s", ex)
+            except Exception as ex:
+                logger.error("Unexpected error creating connection store: %s", ex)
+
+        # Create a proxy to the configured server (auto-creates ProxyClient)
+        proxy = mcp.as_proxy(config, name=tool_name)
+        mcp.mount(proxy, as_proxy=False, prefix="sqlcl")
+    else:
+        logger.warning("Not enabling SQLcl MCP server, sqlcl not found in PATH.")
diff --git a/src/server/mcp/resources/__init__.py b/src/server/mcp/resources/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/server/mcp/tools/__init__.py b/src/server/mcp/tools/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/server/mcp/tools/say_hello.py b/src/server/mcp/tools/say_hello.py
new file mode 100644
index 00000000..b581cd90
--- /dev/null
+++ b/src/server/mcp/tools/say_hello.py
@@ -0,0 +1,6 @@
+async def register(mcp, auth):
+    @mcp.tool(name="optimizer_greet")
+    @auth.get("/hello", operation_id="say_hello")
+    def greet(name: str = "World") -> str:
+        """Say hello to someone."""
+        return f"Hello, {name}!"
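To see say_hello.py's dual registration in action, a hedged in-process check using the same fastmcp Client pattern as the /v1/mcp endpoints (assumes an app built by create_app(); the demo function name is illustrative):

    from fastmcp import Client

    async def greet_demo(fastmcp_app):
        async with Client(fastmcp_app) as client:
            tools = await client.list_tools()  # includes "optimizer_greet"
            result = await client.call_tool("optimizer_greet", {"name": "Optimizer"})
            return tools, result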
diff --git a/src/server/mcp_bak/register_mcp_servers.py b/src/server/mcp_bak/register_mcp_servers.py
new file mode 100644
index 00000000..c08a59b2
--- /dev/null
+++ b/src/server/mcp_bak/register_mcp_servers.py
@@ -0,0 +1,36 @@
+import json
+
+from fastapi import FastAPI, APIRouter, Request
+from fastapi.responses import JSONResponse, PlainTextResponse
+from mcp.server import Server
+
+
+def mount_mcp(router: APIRouter, prefix: str, mcp_server: Server):
+    @router.get(f"{prefix}/.well-known/mcp.json")
+    async def manifest():
+        return JSONResponse(content=mcp_server.manifest.dict())
+
+    @router.post(f"{prefix}/mcp")
+    async def mcp_api(request: Request):
+        body = await request.body()
+        resp = mcp_server.handle_http(body)
+        try:
+            return JSONResponse(content=json.loads(resp))
+        except Exception:
+            return PlainTextResponse(content=resp)
+
+
+def register_mcp_servers(app: FastAPI):
+    # Create routers for MCP endpoints
+    mcp_router = APIRouter()
+
+    # Define MCP servers
+    mcp_sqlcl = Server(name="Built-in SQLcl MCP Server")
+
+    # Example tools
+    @mcp_sqlcl.tool()
+    def greet(name: str) -> str:
+        return f"Hello from MCP Server One, {name}!"
+
+    # Mount MCP servers into the router under prefixes
+    mount_mcp(mcp_router, "/mcp_sqlcl", mcp_sqlcl)
+
+    # Include the MCP router into the main app
+    app.include_router(mcp_router)
\ No newline at end of file
diff --git a/src/server/mcp_bak/server/archive_mcp.py b/src/server/mcp_bak/server/archive_mcp.py
new file mode 100644
index 00000000..d38a091f
--- /dev/null
+++ b/src/server/mcp_bak/server/archive_mcp.py
@@ -0,0 +1,182 @@
+import json
+import os
+from dotenv import load_dotenv
+import arxiv
+from typing import List
+from mcp.server.fastmcp import FastMCP
+import textwrap
+
+# --- Configuration and Setup ---
+load_dotenv()
+PAPER_DIR = "papers"
+# Initialize FastMCP server with a name
+mcp = FastMCP("research")
+_paper_cache = {}
+
+# --- Tool Definitions ---
+
+@mcp.tool()
+def search_papers(topic: str, max_results: int = 5) -> List[str]:
+    """
+    Searches for papers on arXiv based on a topic and saves their metadata.
+
+    Args:
+        topic (str): The topic to search for.
+        max_results (int): Maximum number of results to retrieve.
+
+    Returns:
+        List[str]: A list of the paper IDs found and saved.
+    """
+    client_arxiv = arxiv.Client()
+    search = arxiv.Search(
+        query=topic,
+        max_results=max_results,
+        sort_by=arxiv.SortCriterion.Relevance
+    )
+    papers = list(client_arxiv.results(search))
+
+    if not papers:
+        # It's good practice to print feedback on the server side
+        print(f"Server: No papers found for topic '{topic}'")
+        return []
+
+    path = os.path.join(PAPER_DIR, topic.lower().replace(" ", "_"))
+    os.makedirs(path, exist_ok=True)
+    file_path = os.path.join(path, "papers_info.json")
+
+    try:
+        with open(file_path, "r") as json_file:
+            papers_info = json.load(json_file)
+    except (FileNotFoundError, json.JSONDecodeError):
+        papers_info = {}
+
+    paper_ids = []
+    for paper in papers:
+        paper_id = paper.get_short_id()
+        paper_ids.append(paper_id)
+        papers_info[paper_id] = {
+            'title': paper.title,
+            'authors': [author.name for author in paper.authors],
+            'summary': paper.summary,
+            'pdf_url': paper.pdf_url,
+            'published': str(paper.published.date())
+        }
+
+    with open(file_path, "w") as json_file:
+        json.dump(papers_info, json_file, indent=2)
+
+    print(f"Server: Saved {len(paper_ids)} papers to {file_path}")
+    return paper_ids
+
+@mcp.tool()
+def extract_info(paper_id: str) -> str:
+    """
+    Retrieves saved information for a specific paper ID from all topics.
+    Uses an in-memory cache for performance.
+
+    Args:
+        paper_id (str): The ID of the paper to look for.
+
+    Returns:
+        str: JSON string with paper information if found, else an error message.
+    """
+    # 1. First, check the cache for an exact match
+    if paper_id in _paper_cache:
+        return json.dumps(_paper_cache[paper_id], indent=2)
+
+    # 2. If not in cache, perform the expensive file search (your original logic)
+    for item in os.listdir(PAPER_DIR):
+        item_path = os.path.join(PAPER_DIR, item)
+        if os.path.isdir(item_path):
+            file_path = os.path.join(item_path, "papers_info.json")
+            if os.path.isfile(file_path):
+                try:
+                    with open(file_path, "r") as json_file:
+                        papers_info = json.load(json_file)
+
+                    # Search logic (can be simplified if we populate cache first)
+                    for key, value in papers_info.items():
+                        # Add every paper from this file to the cache to avoid re-reading this file
+                        if key not in _paper_cache:
+                            _paper_cache[key] = value
+
+                except (FileNotFoundError, json.JSONDecodeError):
+                    continue
+
+    # 3. Now that the cache is populated from relevant files, check again.
+    # This handles version differences as well.
+    if paper_id in _paper_cache:
+        return json.dumps(_paper_cache[paper_id], indent=2)
+
+    base_id = paper_id.split('v')[0]
+    for key, value in _paper_cache.items():
+        if key.startswith(base_id):
+            return json.dumps(value, indent=2)
+
+    return f"Error: No saved information found for paper ID {paper_id}."
+
+# --- Resource Definitions ---
+
+@mcp.resource("papers://folders")
+def get_available_folders() -> str:
+    """Lists all available topic folders that contain saved paper information."""
+    print(f"Server: Listing available topic folders in {PAPER_DIR}")
+    folders = []
+    if os.path.exists(PAPER_DIR):
+        for topic_dir in os.listdir(PAPER_DIR):
+            if os.path.isdir(os.path.join(PAPER_DIR, topic_dir)):
+                folders.append(topic_dir)
+
+    content = "# Available Research Topics\n\n"
+    if folders:
+        content += "You can retrieve info for any of these topics using `@`.\n\n"
+        for folder in folders:
+            content += f"- `{folder}`\n"
+    else:
+        content += "No topic folders found. Use `search_papers` to create one."
+    print(f"Server: Found {len(folders)} topic folders.")
+    return content
+
+@mcp.resource("papers://{topic}")
+def get_topic_papers(topic: str) -> str:
+    """Gets detailed information about all saved papers for a specific topic."""
+    print(f"Server: Retrieving papers for topic '{topic}'")
+    topic_dir = topic.lower().replace(" ", "_")
+    papers_file = os.path.join(PAPER_DIR, topic_dir, "papers_info.json")
+
+    if not os.path.exists(papers_file):
+        return f"# No papers found for topic: {topic}"
+
+    with open(papers_file, 'r') as f:
+        papers_data = json.load(f)
+
+    content = f"# Papers on {topic.replace('_', ' ').title()}\n\n"
+    for paper_id, info in papers_data.items():
+        content += f"## {info['title']} (`{paper_id}`)\n"
+        content += f"- **Authors**: {', '.join(info['authors'])}\n"
+        content += f"- **Summary**: {info['summary'][:200]}...\n---\n"
+    print(f"Server: Found {len(papers_data)} papers for topic '{topic}'")
+    return content
+
+# --- Prompt Definition ---
+
+@mcp.prompt()
+def generate_search_prompt(topic: str) -> str:
+    """Generates a system prompt to guide an AI in researching a topic."""
+    return textwrap.dedent(f"""
+        You are a research assistant. Your goal is to provide a comprehensive overview of a topic.
+        When asked about '{topic}', follow these steps:
+        1. Use the `search_papers` tool to find relevant papers.
+        2. For each paper ID returned, use the `extract_info` tool to get its details.
+        3. Synthesize the information from all papers into a cohesive summary.
+        4. Present the key findings, common themes, and any differing conclusions.
+        Do not present the raw JSON. Format the final output for readability.
+    """)
+
+# --- Main Execution Block ---
+
+if __name__ == "__main__":
+    # This is the original, simple, and correct way to run the server.
+    # It will not crash.
+    print("Research MCP Server running on stdio...")
+    mcp.run(transport='stdio')
diff --git a/src/server/mcp_bak/server_config.json b/src/server/mcp_bak/server_config.json
new file mode 100644
index 00000000..3d8b0321
--- /dev/null
+++ b/src/server/mcp_bak/server_config.json
@@ -0,0 +1,20 @@
+{
+    "mcpServers": {
+        "filesystem": {
+            "command": "npx",
+            "args": [
+                "-y",
+                "@modelcontextprotocol/server-filesystem",
+                "."
+            ]
+        },
+        "research": {
+            "command": "python3",
+            "args": ["server/mcp/server/archive_mcp.py"]
+        },
+        "fetch": {
+            "command": "python3",
+            "args": ["-m", "mcp_server_fetch"]
+        }
+    }
+}
diff --git a/tests/conftest.py b/tests/conftest.py
index 3c8b94f6..ec1b5bf9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,6 +8,7 @@
 # pylint: disable=import-outside-toplevel
 
 import os
+import asyncio
 
 # This contains all the environment variables we consume on startup (add as required)
 # Used to clear testing environment from users env; Do before any additional imports
@@ -72,7 +73,7 @@ def client():
     # Lazy Load
     from launch_server import create_app
 
-    app = create_app()
+    app = asyncio.run(create_app())
     return TestClient(app)
 
@@ -110,14 +111,14 @@ def is_port_in_use(port):
 
     server_process = subprocess.Popen(cmd, cwd="src")
 
-    # Wait for server to be ready (up to 30 seconds)
-    max_wait = 30
+    # Wait for server to be ready
+    max_wait = 60
     start_time = time.time()
     while not is_port_in_use(8015):
         if time.time() - start_time > max_wait:
             server_process.terminate()
             server_process.wait()
-            raise TimeoutError("Server failed to start within 30 seconds")
+            raise TimeoutError(f"Server failed to start within {max_wait} seconds")
         time.sleep(0.5)
 
     yield server_process
@@ -132,7 +133,7 @@ def app_test(auth_headers):
     """Establish Streamlit State for Client to Operate"""
 
     def _app_test(page):
-        at = AppTest.from_file(page, default_timeout=30)
+        at = AppTest.from_file(page, default_timeout=60)
         at.session_state.server = {
             "key": os.environ.get("API_SERVER_KEY"),
             "url": os.environ.get("API_SERVER_URL"),
diff --git a/tests/unit/server/api/utils/test_models.py b/tests/unit/server/api/utils/test_models.py
new file mode 100644
index 00000000..04728f13
--- /dev/null
+++ b/tests/unit/server/api/utils/test_models.py
@@ -0,0 +1,36 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+""" + +# spell-checker: disable +import os +import pytest + +import server.api.core.models as core_models +import server.api.utils.models as utils_models + +os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" +os.environ["LITELLM_DISABLE_SPEND_LOGS"] = "True" +os.environ["LITELLM_DISABLE_SPEND_UPDATES"] = "True" +os.environ["LITELLM_DISABLE_END_USER_COST_TRACKING"] = "True" +os.environ["LITELLM_DROP_PARAMS"] = "True" +os.environ["LITELLM_DROP_PARAMS"] = "True" + +@pytest.fixture(name="models_list") +def _models_list(): + model_objects = core_models.get_model() + for obj in model_objects: + obj.enabled = True + return model_objects + + +def test_get_litellm_client(models_list): + """Testing LiteLLM Functionality""" + assert isinstance(models_list, list) + assert len(models_list) > 0 + + for model in models_list: + print(f"My Model: {model}") + if model.id == "mxbai-embed-large": + utils_models.get_litellm_client(model.dict())