diff --git a/llm-service/app/routers/index/__init__.py b/llm-service/app/routers/index/__init__.py index d530e011f..99d45f843 100644 --- a/llm-service/app/routers/index/__init__.py +++ b/llm-service/app/routers/index/__init__.py @@ -40,7 +40,7 @@ from fastapi import APIRouter -from . import data_source, tools +from . import data_source, tools, custom_tools from . import sessions from . import summaries from . import amp_metadata @@ -62,3 +62,4 @@ router.include_router(models.router) router.include_router(metrics.router) router.include_router(tools.router) +router.include_router(custom_tools.router) diff --git a/llm-service/app/routers/index/custom_tools.py b/llm-service/app/routers/index/custom_tools.py new file mode 100644 index 000000000..477075a90 --- /dev/null +++ b/llm-service/app/routers/index/custom_tools.py @@ -0,0 +1,353 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2025 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# + +import logging +import os +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, File, Form, Header, HTTPException, UploadFile +from pydantic import BaseModel + +from app.config import settings +from app import exceptions + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/custom-tools", tags=["Custom Tools"]) + + +class UserToolCreateRequest(BaseModel): + """Request model for creating a user tool.""" + + name: str + display_name: str + description: str + + +class UserToolResponse(BaseModel): + """Response model for user tools.""" + + name: str + display_name: str + description: str + script_path: str + + +class UserToolTestRequest(BaseModel): + """Request model for testing a user tool.""" + + input_data: Dict[str, Any] + + +@router.get("", summary="Get user tools", response_model=List[UserToolResponse]) +@exceptions.propagates +def get_user_tools() -> List[UserToolResponse]: + """Get all tools for the current user.""" + try: + from app.services.query.agents.agent_tools.dynamic_mcp import UserToolStorage + + storage = UserToolStorage() + tools_data = storage.get_custom_tools() + + responses = [] + for tool in tools_data: + display_name, description = ( + tool["metadata"]["display_name"], + tool["metadata"]["description"], + ) + responses.append( + UserToolResponse( + name=tool["name"], + display_name=display_name, + description=description, + script_path=tool["script_path"], + ) + ) + return responses + except Exception as e: + logger.error(f"Error getting user tools: {e}") + raise HTTPException(status_code=500, detail=f"Error retrieving tools: {str(e)}") + + +@router.post("", summary="Create user tool", response_model=UserToolResponse) +@exceptions.propagates +def create_user_tool( + name: str = Form(...), + display_name: str = Form(...), + description: str = Form(...), + script_file: UploadFile = File(...), + origin_remote_user: Optional[str] = Header(None), +) -> UserToolResponse: + """Create a new user tool.""" + try: + from app.services.query.agents.agent_tools.dynamic_mcp import ( + UserToolDefinition, + UserToolStorage, + ) + + username = origin_remote_user or "default_user" + storage = UserToolStorage() + + # Check if tool already exists + existing_tool = storage.get_tool(username, name) + if existing_tool: + raise HTTPException(status_code=400, detail=f"Tool '{name}' already exists") + + # Validate file type + if not script_file.filename or not script_file.filename.endswith(".py"): + raise HTTPException( + status_code=400, detail="Script file must be a Python (.py) file" + ) + + # Read file content + file_content = script_file.file.read().decode("utf-8") + + # Save the script file and get the path + script_path = storage.save_script_file(name, file_content) + + # Create full path for validation + try: + full_script_path = os.path.join(settings.tools_dir, script_path) + except ImportError: + full_script_path = os.path.join("..", "tools", script_path) + + # Create and validate the tool + tool = UserToolDefinition( + name=name, + display_name=display_name, + description=description, + script_path=full_script_path, + ) + + # Save the tool + storage.save_tool(tool) + + return UserToolResponse( + name=tool.name, + display_name=tool.display_name, + description=tool.description, + script_path=script_path, # Return relative path + ) + + except ValueError as e: + # Validation errors from tool creation + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error creating user tool: {e}") + raise HTTPException(status_code=500, detail=f"Error creating tool: {str(e)}") + + +@router.get("/{tool_name}", summary="Get user tool", response_model=UserToolResponse) +@exceptions.propagates +def get_user_tool( + tool_name: str, origin_remote_user: Optional[str] = Header(None) +) -> UserToolResponse: + """Get a specific user tool.""" + try: + from app.services.query.agents.agent_tools.dynamic_mcp import UserToolStorage + + username = origin_remote_user or "default_user" + storage = UserToolStorage() + + tool_data = storage.get_tool(username, tool_name) + if not tool_data: + raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found") + + display_name, description = ( + tool_data["metadata"]["display_name"], + tool_data["metadata"]["description"], + ) + + return UserToolResponse( + name=tool_data["name"], + display_name=display_name, + description=description, + script_path=tool_data["script_path"], + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting user tool: {e}") + raise HTTPException(status_code=500, detail=f"Error retrieving tool: {str(e)}") + + +@router.put("/{tool_name}", summary="Update user tool", response_model=UserToolResponse) +@exceptions.propagates +def update_user_tool( + tool_name: str, + name: str = Form(...), + display_name: str = Form(...), + description: str = Form(...), + script_file: UploadFile = File(...), + origin_remote_user: Optional[str] = Header(None), +) -> UserToolResponse: + """Update an existing user tool.""" + try: + from app.services.query.agents.agent_tools.dynamic_mcp import ( + UserToolDefinition, + UserToolStorage, + ) + + username = origin_remote_user or "default_user" + storage = UserToolStorage() + + # Check if tool exists + existing_tool = storage.get_tool(username, tool_name) + if not existing_tool: + raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found") + + # Validate file type + if not script_file.filename or not script_file.filename.endswith(".py"): + raise HTTPException( + status_code=400, detail="Script file must be a Python (.py) file" + ) + + # Read file content + file_content = script_file.file.read().decode("utf-8") + + # Save the script file and get the path (this will overwrite the old file) + script_path = storage.save_script_file(name, file_content) + + # Create full path for validation + try: + full_script_path = os.path.join(settings.tools_dir, script_path) + except ImportError: + full_script_path = os.path.join("..", "tools", script_path) + + # Create and validate the updated tool + tool = UserToolDefinition( + name=name, + display_name=display_name, + description=description, + script_path=full_script_path, + ) + + # Save the updated tool + storage.save_tool(tool) + + return UserToolResponse( + name=tool.name, + display_name=tool.display_name, + description=tool.description, + script_path=script_path, # Return relative path + ) + + except ValueError as e: + # Validation errors from tool creation + raise HTTPException(status_code=400, detail=str(e)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating user tool: {e}") + raise HTTPException(status_code=500, detail=f"Error updating tool: {str(e)}") + + +@router.delete("/{tool_name}", summary="Delete user tool") +@exceptions.propagates +def delete_user_tool( + tool_name: str, origin_remote_user: Optional[str] = Header(None) +) -> Dict[str, str]: + """Delete a user tool.""" + try: + from app.services.query.agents.agent_tools.dynamic_mcp import UserToolStorage + + username = origin_remote_user or "default_user" + storage = UserToolStorage() + + success = storage.delete_tool(username, tool_name) + if not success: + raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found") + + return {"message": f"Tool '{tool_name}' deleted successfully"} + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error deleting user tool: {e}") + raise HTTPException(status_code=500, detail=f"Error deleting tool: {str(e)}") + + +@router.post("/{tool_name}/test", summary="Test user tool") +@exceptions.propagates +def test_user_tool( + tool_name: str, + request: UserToolTestRequest, + origin_remote_user: Optional[str] = Header(None), +) -> Dict[str, Any]: + """Test a user tool with provided input.""" + try: + from app.services.query.agents.agent_tools.dynamic_mcp import ( + UserToolStorage, + create_user_tool_from_dict, + ) + + username = origin_remote_user or "default_user" + storage = UserToolStorage() + + tool_data = storage.get_tool(username, tool_name) + if not tool_data: + raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found") + + # Create the tool and test it + tool = create_user_tool_from_dict(tool_data) + result = tool.execute(**request.input_data) + + return {"success": True, "result": result, "input": request.input_data} + + except ValueError as e: + return { + "success": False, + "error": f"Validation error: {str(e)}", + "input": request.input_data, + } + except RuntimeError as e: + return { + "success": False, + "error": f"Execution error: {str(e)}", + "input": request.input_data, + } + except HTTPException: + raise + except Exception as e: + logger.error(f"Error testing user tool: {e}") + return { + "success": False, + "error": f"Unexpected error: {str(e)}", + "input": request.input_data, + } diff --git a/llm-service/app/services/query/agents/agent_tools/dynamic_mcp.py b/llm-service/app/services/query/agents/agent_tools/dynamic_mcp.py new file mode 100644 index 000000000..f1491f3c7 --- /dev/null +++ b/llm-service/app/services/query/agents/agent_tools/dynamic_mcp.py @@ -0,0 +1,472 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2025 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# + +import ast +import json +import logging +import os +from typing import Any, Dict, List, Optional, Type, cast + +from llama_index.core.tools import FunctionTool +from pydantic import BaseModel, create_model +from app.config import settings + +logger = logging.getLogger(__name__) + + +class UserToolDefinition: + """ + Represents a user-submitted tool with its schema and code. + """ + + def __init__( + self, + name: str, + display_name: str, + description: str, + script_path: str, + ) -> None: + self.name = name + self.display_name = display_name + self.description = description + self.script_path = script_path + + # Validate and prepare the function + self._validate_script_path() + self._prepare_function() + self.function_schema = self._extract_function_schema() + + def _extract_function_schema(self) -> Dict[str, Any]: + """ + Extracts a JSON schema from the main function in the script file. + """ + with open(self.script_path, "r") as f: + function_code = f.read() + tree = ast.parse(function_code) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + func_node = node + break + else: + raise ValueError("No function definition found in script.") + + # Extract argument names, types, and docstring + properties = {} + required = [] + for arg in func_node.args.args: + if arg.arg == "self": + continue + arg_type = "string" # default type + if arg.annotation: + ann = ast.unparse(arg.annotation) + if ann in ["int", "float", "bool", "list", "dict"]: + arg_type = { + "int": "integer", + "float": "number", + "bool": "boolean", + "list": "array", + "dict": "object", + }[ann] + properties[arg.arg] = {"type": arg_type} + required.append(arg.arg) + docstring = ast.get_docstring(func_node) or "" + return { + "title": func_node.name, + "description": docstring, + "type": "object", + "properties": properties, + "required": required, + } + + def _validate_script_path(self) -> None: + """Validate that the script path exists and the code is safe to execute.""" + if not os.path.exists(self.script_path): + raise ValueError(f"Script file not found: {self.script_path}") + + try: + with open(self.script_path, "r") as f: + function_code = f.read() + # Parse the code to ensure it's valid Python + tree = ast.parse(function_code) + except SyntaxError as e: + raise ValueError(f"Invalid Python syntax in script: {e}") + except IOError as e: + raise ValueError(f"Error reading script file: {e}") + + # Security checks - disallow dangerous imports and operations + dangerous_patterns = [ + "import os", + "import subprocess", + "import sys", + "import socket", + "exec(", + "eval(", + "__import__", + "open(", + "file(", + "compile(", + "globals(", + "locals(", + "vars(", + ] + + for pattern in dangerous_patterns: + if pattern in function_code: + raise ValueError(f"Dangerous operation detected: {pattern}") + + # Ensure there's at least one function definition + function_names = [] + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + function_names.append(node.name) + + if not function_names: + raise ValueError("Script must contain at least one function definition") + + def _prepare_function(self) -> None: + """Prepare the function for execution.""" + # Read the code from the script file + with open(self.script_path, "r") as f: + function_code = f.read() + + # Extract the main function from the code + tree = ast.parse(function_code) + + # Find the first function definition + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + self.main_function_name = node.name + break + else: + raise ValueError("No function definition found") + + def _create_input_model(self) -> Type[BaseModel]: + """Create a Pydantic model from the function schema.""" + properties = self.function_schema.get("properties", {}) + required_fields = self.function_schema.get("required", []) + + # Convert JSON Schema types to Python types + type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + } + + field_definitions: Dict[str, Any] = {} + for field_name, field_schema in properties.items(): + field_type = type_mapping.get(field_schema.get("type", "string"), str) + default_value = ... if field_name in required_fields else None + field_definitions[field_name] = (field_type, default_value) + + return cast( + Type[BaseModel], create_model(f"{self.name}Input", **field_definitions) + ) + + def execute(self, **kwargs: Any) -> Any: + """Execute the user's function with the provided arguments.""" + try: + # Create a restricted execution environment + safe_globals = { + "__builtins__": { + "len": len, + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, + "tuple": tuple, + "set": set, + "range": range, + "enumerate": enumerate, + "zip": zip, + "max": max, + "min": min, + "sum": sum, + "abs": abs, + "round": round, + "sorted": sorted, + "reversed": reversed, + "print": print, # Allow print for debugging + # Common exceptions + "ValueError": ValueError, + "TypeError": TypeError, + "KeyError": KeyError, + "IndexError": IndexError, + "ZeroDivisionError": ZeroDivisionError, + "AttributeError": AttributeError, + "RuntimeError": RuntimeError, + } + } + + # Read and execute the function code from script file + with open(self.script_path, "r") as f: + function_code = f.read() + exec(function_code, safe_globals) + + # Get the function from the executed environment + user_function = safe_globals.get(self.main_function_name) + if not user_function: + raise RuntimeError( + f"Function '{self.main_function_name}' not found after execution" + ) + + # Ensure it's callable + if not callable(user_function): + raise RuntimeError( + f"'{self.main_function_name}' is not a callable function" + ) + + # Call the function with the provided arguments + # We've verified user_function is callable + result = user_function(**kwargs) + + return result + + except Exception as e: + logger.error(f"Error executing user tool '{self.name}': {e}") + raise RuntimeError(f"Tool execution failed: {e}") + + def to_function_tool(self) -> FunctionTool: + """Convert this user tool to a LlamaIndex FunctionTool.""" + + # Create the input model + input_model = self._create_input_model() + + # Create the function tool + def tool_function(**kwargs: Any) -> Any: + try: + result = self.execute(**kwargs) + # Ensure result is JSON serializable + if isinstance(result, (str, int, float, bool, list, dict)): + return result + else: + return str(result) + except Exception as e: + return f"Error: {str(e)}" + + return FunctionTool.from_defaults( + fn=tool_function, + name=self.name, + description=self.description, + fn_schema=input_model, + ) + + +class UserToolStorage: + """ + Unified storage for user tools in mcp.json file. + """ + + def __init__(self) -> None: + # Use the tools directory from settings + try: + self.mcp_json_path = os.path.join(settings.tools_dir, "mcp.json") + self.scripts_dir = os.path.join(settings.tools_dir, "custom_tool_scripts") + except ImportError: + self.mcp_json_path = os.path.join("..", "tools", "mcp.json") + self.scripts_dir = os.path.join("..", "tools", "custom_tool_scripts") + + # Ensure scripts directory exists + os.makedirs(self.scripts_dir, exist_ok=True) + + def _read_mcp_config(self) -> Dict[str, Any]: + """Read the entire mcp.json configuration.""" + if not os.path.exists(self.mcp_json_path): + return {"mcp_servers": [], "custom_tools": []} + + try: + with open(self.mcp_json_path, "r") as f: + config = cast(Dict[str, Any], json.load(f)) + # Ensure custom_tools array exists + if "custom_tools" not in config: + config["custom_tools"] = [] + return config + except (json.JSONDecodeError, IOError) as e: + logger.error(f"Error reading mcp.json: {e}") + return {"mcp_servers": [], "custom_tools": []} + + def _write_mcp_config(self, config: Dict[str, Any]) -> None: + """Write the entire mcp.json configuration.""" + try: + with open(self.mcp_json_path, "w") as f: + json.dump(config, f, indent=2) + except IOError as e: + logger.error(f"Error writing mcp.json: {e}") + raise RuntimeError(f"Failed to save tool configuration: {e}") + + def save_script_file(self, tool_name: str, file_content: str) -> str: + """Save a Python script file and return the relative path.""" + script_filename = f"{tool_name}.py" + script_path = os.path.join(self.scripts_dir, script_filename) + + with open(script_path, "w") as f: + f.write(file_content) + + # Return relative path for storage in mcp.json + return os.path.join("custom_tool_scripts", script_filename) + + def save_tool(self, tool: UserToolDefinition) -> None: + """Save a user tool to mcp.json.""" + config = self._read_mcp_config() + + tool_data = { + "name": tool.name, + "metadata": { + "display_name": tool.display_name, + "description": tool.description, + }, + "function_schema": tool.function_schema, + "script_path": tool.script_path, + } + + # Remove existing tool with same name + config["custom_tools"] = [ + t for t in config["custom_tools"] if t.get("name") != tool.name + ] + config["custom_tools"].append(tool_data) + + # Save to mcp.json + self._write_mcp_config(config) + + def get_custom_tools(self, username: Optional[str] = None) -> List[Dict[str, Any]]: + """Get all custom tools (username parameter ignored for unified storage).""" + config = self._read_mcp_config() + return cast(List[Dict[str, Any]], config.get("custom_tools", [])) + + def get_tool(self, username: str, tool_name: str) -> Optional[Dict[str, Any]]: + """Get a specific tool (username parameter ignored for unified storage).""" + tools = self.get_custom_tools() + for tool in tools: + if tool.get("name") == tool_name: + return tool + return None + + def delete_tool(self, username: str, tool_name: str) -> bool: + """Delete a tool (username parameter ignored for unified storage).""" + config = self._read_mcp_config() + + # Find the tool to get its script path + tool_to_delete = None + for tool in config["custom_tools"]: + if tool.get("name") == tool_name: + tool_to_delete = tool + break + + if not tool_to_delete: + return False # Tool not found + + # Remove the tool from config + config["custom_tools"] = [ + t for t in config["custom_tools"] if t.get("name") != tool_name + ] + + # Delete the script file if it exists + if "script_path" in tool_to_delete: + try: + script_full_path = os.path.join( + settings.tools_dir, tool_to_delete["script_path"] + ) + if os.path.exists(script_full_path): + os.remove(script_full_path) + except (ImportError, OSError) as e: + logger.warning(f"Could not delete script file: {e}") + + # Save updated config + self._write_mcp_config(config) + return True + + +def create_user_tool_from_dict(tool_data: Dict[str, Any]) -> UserToolDefinition: + """Create a UserToolDefinition from a dictionary.""" + # Convert relative script path to absolute path + script_path = tool_data["script_path"] + if not os.path.isabs(script_path): + try: + script_path = os.path.join(settings.tools_dir, script_path) + except ImportError: + script_path = os.path.join("..", "tools", script_path) + + # Handle both old and new metadata structure for backward compatibility + if "metadata" in tool_data: + # New structure with metadata + display_name = tool_data["metadata"]["display_name"] + description = tool_data["metadata"]["description"] + else: + # Old structure for backward compatibility + display_name = tool_data.get("display_name", tool_data["name"]) + description = tool_data.get("description", "") + + return UserToolDefinition( + name=tool_data["name"], + display_name=display_name, + description=description, + script_path=script_path, + ) + + +def get_custom_function_tools(username: Optional[str] = None) -> List[FunctionTool]: + """Get all FunctionTools for custom user-submitted tools. + + Args: + username: Ignored in unified storage (kept for API compatibility). + + Returns: + List of FunctionTool objects for all custom tools. + """ + storage = UserToolStorage() + tools_data = storage.get_custom_tools() + + function_tools = [] + for tool_data in tools_data: + try: + user_tool = create_user_tool_from_dict(tool_data) + function_tool = user_tool.to_function_tool() + function_tools.append(function_tool) + except Exception as e: + logger.error( + "Error creating function tool from %s: %s", + tool_data.get("name", "unknown"), + e, + ) + continue + + return function_tools diff --git a/llm-service/app/services/query/agents/tool_calling_querier.py b/llm-service/app/services/query/agents/tool_calling_querier.py index ed4f92c05..93b97fd55 100644 --- a/llm-service/app/services/query/agents/tool_calling_querier.py +++ b/llm-service/app/services/query/agents/tool_calling_querier.py @@ -204,11 +204,29 @@ def stream_chat( if session.query_configuration and session.query_configuration.selected_tools: for tool_name in session.query_configuration.selected_tools: try: + # Try to load as static MCP tool first mcp_tools.extend(get_llama_index_tools(tool_name)) except ValueError as e: logger.warning(f"Could not create adapter for tool {tool_name}: {e}") continue + # Also load user-submitted tools for this session + # For now, we'll load tools for a default user - this would need to be + # updated to get the actual user from the session context + try: + from app.services.query.agents.agent_tools.dynamic_mcp import ( + get_custom_function_tools, + ) + + # TODO: Get actual username from session context + custom_tools = get_custom_function_tools() + print(f"Loaded {len(custom_tools)} custom tools") + mcp_tools.extend(custom_tools) + except ImportError as e: + logger.warning(f"Could not load user tools: {e}") + except Exception as e: + logger.warning(f"Error loading user tools: {e}") + # Use the existing chat engine with the enhanced query for streaming response tools: list[BaseTool] = mcp_tools # Use tool calling only if retrieval is not the only tool to optimize performance diff --git a/llm-service/examples/README.md b/llm-service/examples/README.md new file mode 100644 index 000000000..9f8991b45 --- /dev/null +++ b/llm-service/examples/README.md @@ -0,0 +1,167 @@ +# User Tools Examples + +This directory contains examples demonstrating how to use the RAG Studio User Tools system, which allows users to submit custom Python functions that get wrapped into MCP servers and can be used in chat sessions. + +## Overview + +The User Tools system enables users to: + +- Submit custom Python functions with JSON schema definitions +- Test their tools before using them in chat +- Have their tools automatically validated for security +- Use their tools alongside built-in MCP tools in RAG Studio chat sessions + +## Calculator Tool Example + +The `calculator_tool_example.py` demonstrates a simple arithmetic calculator tool that can perform basic operations (add, subtract, multiply, divide). + +### Running the Example + +1. **Start the RAG Studio backend** (make sure it's running on `http://localhost:8000`) + +2. **Run the calculator example:** + + ```bash + cd llm-service/examples + python calculator_tool_example.py --username your_username + ``` + +3. **Or run specific actions:** + + ```bash + # Just submit the tool + python calculator_tool_example.py --action submit --username your_username + + # Just test the tool + python calculator_tool_example.py --action test --username your_username + + # Just list tools + python calculator_tool_example.py --action list --username your_username + ``` + +### What the Example Does + +1. **Submits** a calculator tool with: + + - Name: `simple_calculator` + - Inputs: `first_number`, `second_number`, `operation` + - Function: Python code that performs arithmetic operations + +2. **Tests** the tool with sample calculations: + + - 10 + 5 = 15 + - 10 - 3 = 7 + - 7 ร— 6 = 42 + - 15 รท 3 = 5 + +3. **Lists** all tools for the user + +## Creating Your Own Tools + +### Tool Definition Format + +Each user tool consists of: + +```python +{ + "name": "tool_name", # Unique identifier (alphanumeric + underscores) + "display_name": "Human Readable Name", # Name shown in UI + "description": "What this tool does", # Description for LLM and users + "function_schema": { # JSON Schema for inputs + "type": "object", + "properties": { + "param1": { + "type": "string", + "description": "First parameter" + }, + "param2": { + "type": "number", + "description": "Second parameter" + } + }, + "required": ["param1", "param2"] + }, + "function_code": '''def tool_name(param1: str, param2: float) -> str: + """Your function implementation here""" + return f"Result: {param1} + {param2}"''' +} +``` + +### Function Requirements + +Your Python function must: + +1. **Match the schema**: Function parameters must match the `function_schema` properties +2. **Be self-contained**: No external imports (except safe builtins like `len`, `str`, `int`, etc.) +3. **Be secure**: No file system access, network calls, or dangerous operations +4. **Have type hints**: Input and return types should be specified +5. **Include docstring**: Describe what the function does + +### Security Restrictions + +For security, user functions cannot: + +- Import modules (`import os`, `import subprocess`, etc.) +- Access files (`open()`, `file()`) +- Execute code (`exec()`, `eval()`, `compile()`) +- Access system internals (`globals()`, `locals()`, `__import__`) +- Make network calls (`socket`, `urllib`, etc.) + +### Supported Types + +Function parameters and return values can be: + +- Basic types: `str`, `int`, `float`, `bool` +- Collections: `list`, `dict` +- Optional types: `Optional[str]`, etc. +- Union types: `Union[str, int]`, etc. + +## API Endpoints + +The user tools system provides these REST endpoints: + +- `GET /user-tools` - List all tools for authenticated user +- `POST /user-tools` - Submit a new tool +- `GET /user-tools/{tool_name}` - Get specific tool details +- `PUT /user-tools/{tool_name}` - Update existing tool +- `DELETE /user-tools/{tool_name}` - Delete tool +- `POST /user-tools/{tool_name}/test` - Test tool with sample inputs + +## Authentication + +All endpoints require the `origin_remote_user` header to identify the user. Each user's tools are isolated from other users. + +## Using Tools in Chat + +Once submitted and tested, your tools automatically become available in RAG Studio chat sessions. The LLM can call your tools just like any built-in tool when they're relevant to the conversation. + +To use a tool in chat: + +1. Submit your tool via the API or UI +2. Start a chat session in RAG Studio +3. Ask questions that would benefit from your tool +4. The LLM will automatically call your tool when appropriate + +## Example Tools You Could Create + +- **Unit Converter**: Convert between different units (meters to feet, celsius to fahrenheit) +- **Text Processor**: Count words, reverse text, format strings +- **Math Helper**: Calculate percentages, compound interest, geometric formulas +- **Data Validator**: Check email formats, phone numbers, credit card numbers +- **Code Generator**: Generate SQL queries, regular expressions, HTML snippets +- **Business Logic**: Calculate taxes, shipping costs, discount pricing + +## Troubleshooting + +### Common Issues + +1. **Tool submission fails**: Check that your function syntax is valid Python +2. **Security validation fails**: Remove any imports or dangerous operations +3. **Schema mismatch**: Ensure function parameters match the JSON schema exactly +4. **Tool not appearing in chat**: Verify the tool was submitted successfully via the list endpoint + +### Getting Help + +- Check the backend logs for detailed error messages +- Use the test endpoint to debug your function logic +- Ensure your JSON schema validates with online schema validators diff --git a/llm-service/examples/calculator_tool_demo.py b/llm-service/examples/calculator_tool_demo.py new file mode 100644 index 000000000..73c0362c3 --- /dev/null +++ b/llm-service/examples/calculator_tool_demo.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +""" +Standalone Calculator Tool Demo + +This script demonstrates the calculator tool definition and testing +without requiring the RAG Studio API to be running. +""" + +import json + +# This is the exact tool definition format used in the user tools system +CALCULATOR_TOOL_DEFINITION = { + "name": "simple_calculator", + "display_name": "Simple Calculator", + "description": "Performs basic arithmetic operations on two numbers", + "function_schema": { + "type": "object", + "properties": { + "first_number": { + "type": "number", + "description": "The first number in the operation", + }, + "second_number": { + "type": "number", + "description": "The second number in the operation", + }, + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform", + }, + }, + "required": ["first_number", "second_number", "operation"], + }, + "function_code": '''def simple_calculator(first_number: float, second_number: float, operation: str) -> float: + """ + Performs basic arithmetic operations on two numbers. + + Args: + first_number: The first number in the operation + second_number: The second number in the operation + operation: The arithmetic operation to perform (add, subtract, multiply, divide) + + Returns: + The result of the arithmetic operation + + Raises: + ValueError: If operation is not supported or division by zero + """ + if operation == "add": + return first_number + second_number + elif operation == "subtract": + return first_number - second_number + elif operation == "multiply": + return first_number * second_number + elif operation == "divide": + if second_number == 0: + raise ValueError("Cannot divide by zero") + return first_number / second_number + else: + raise ValueError(f"Unsupported operation: {operation}")''', +} + + +def execute_function_code(function_code: str, function_name: str, **kwargs): + """ + Execute user function code safely (simulates the dynamic_mcp execution). + + This is similar to how the actual UserToolDefinition.execute() method works. + """ + # Create a restricted globals environment (similar to dynamic_mcp.py) + safe_globals = { + "__builtins__": { + "len": len, + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, + "tuple": tuple, + "set": set, + "min": min, + "max": max, + "sum": sum, + "abs": abs, + "round": round, + "range": range, + "enumerate": enumerate, + "zip": zip, + "isinstance": isinstance, + "type": type, + # Common exceptions + "ValueError": ValueError, + "TypeError": TypeError, + "KeyError": KeyError, + "IndexError": IndexError, + "ZeroDivisionError": ZeroDivisionError, + "AttributeError": AttributeError, + "RuntimeError": RuntimeError, + } + } + + # Execute the function code in restricted environment + local_vars = {} + exec(function_code, safe_globals, local_vars) + + # Get the function and call it + if function_name in local_vars: + func = local_vars[function_name] + return func(**kwargs) + else: + raise ValueError(f"Function {function_name} not found in code") + + +def demo_calculator(): + """Demonstrate the calculator tool definition and usage.""" + + print("๐Ÿ”ง Calculator Tool Demo") + print("=" * 50) + + # Show the tool definition + print("\n๐Ÿ“‹ Tool Definition:") + print(f"Name: {CALCULATOR_TOOL_DEFINITION['name']}") + print(f"Display Name: {CALCULATOR_TOOL_DEFINITION['display_name']}") + print(f"Description: {CALCULATOR_TOOL_DEFINITION['description']}") + + print("\n๐Ÿ“„ Function Schema:") + schema = CALCULATOR_TOOL_DEFINITION["function_schema"] + print(json.dumps(schema, indent=2)) + + print("\n๐Ÿ Function Code:") + print(CALCULATOR_TOOL_DEFINITION["function_code"]) + + print("\n" + "=" * 50) + print("๐Ÿงช Testing Calculator Tool") + print("=" * 50) + + # Test cases + test_cases = [ + {"first_number": 10, "second_number": 5, "operation": "add", "expected": 15}, + { + "first_number": 10, + "second_number": 3, + "operation": "subtract", + "expected": 7, + }, + { + "first_number": 7, + "second_number": 6, + "operation": "multiply", + "expected": 42, + }, + {"first_number": 15, "second_number": 3, "operation": "divide", "expected": 5}, + { + "first_number": 10, + "second_number": 0, + "operation": "divide", + "expected": "ERROR", + }, + ] + + function_code = CALCULATOR_TOOL_DEFINITION["function_code"] + function_name = CALCULATOR_TOOL_DEFINITION["name"] + + for i, test_case in enumerate(test_cases, 1): + expected = test_case.pop("expected") + + print( + f"\nTest {i}: {test_case['first_number']} {test_case['operation']} {test_case['second_number']}" + ) + + try: + result = execute_function_code(function_code, function_name, **test_case) + print(f" Result: {result}") + + if expected == "ERROR": + print(f" โŒ Expected error but got result: {result}") + elif result == expected: + print(f" โœ… Correct! Expected {expected}") + else: + print(f" โŒ Wrong! Expected {expected}, got {result}") + + except Exception as e: + print(f" Error: {e}") + if expected == "ERROR": + print(f" โœ… Correct! Expected an error") + else: + print(f" โŒ Unexpected error! Expected {expected}") + + +def show_api_usage(): + """Show how this tool would be used with the actual API.""" + + print("\n" + "=" * 50) + print("๐ŸŒ API Usage Example") + print("=" * 50) + + print("\n1. Submit the tool:") + print("```bash") + print("curl -X POST http://localhost:8000/user-tools \\") + print(" -H 'Content-Type: application/json' \\") + print(" -H 'origin_remote_user: your_username' \\") + print(" -d '{") + print(' "name": "simple_calculator",') + print(' "display_name": "Simple Calculator",') + print(' "description": "Performs basic arithmetic operations on two numbers",') + print(' "function_schema": { ... },') + print(' "function_code": "def simple_calculator(...): ..."') + print(" }'") + print("```") + + print("\n2. Test the tool:") + print("```bash") + print("curl -X POST http://localhost:8000/user-tools/simple_calculator/test \\") + print(" -H 'Content-Type: application/json' \\") + print(" -H 'origin_remote_user: your_username' \\") + print(" -d '{") + print(' "input_data": {') + print(' "first_number": 10,') + print(' "second_number": 5,') + print(' "operation": "add"') + print(" }") + print(" }'") + print("```") + + print("\n3. Use in chat:") + print( + "Once submitted, the tool is automatically available in RAG Studio chat sessions." + ) + print("The LLM can call it when users ask mathematical questions:") + print(' User: "What is 15 divided by 3?"') + print(' Assistant: *calls simple_calculator(15, 3, "divide")* โ†’ "The result is 5"') + + +def show_more_examples(): + """Show examples of other tools that could be created.""" + + print("\n" + "=" * 50) + print("๐Ÿ’ก More Tool Ideas") + print("=" * 50) + + examples = [ + { + "name": "text_counter", + "description": "Count words, characters, and lines in text", + "example_inputs": {"text": "Hello world!", "count_type": "words"}, + "example_output": 2, + }, + { + "name": "temperature_converter", + "description": "Convert between Celsius, Fahrenheit, and Kelvin", + "example_inputs": { + "temperature": 100, + "from_unit": "celsius", + "to_unit": "fahrenheit", + }, + "example_output": 212.0, + }, + { + "name": "percentage_calculator", + "description": "Calculate percentages, percentage change, etc.", + "example_inputs": { + "value": 80, + "total": 200, + "calculation": "percentage_of", + }, + "example_output": 40.0, + }, + { + "name": "string_formatter", + "description": "Format strings (uppercase, lowercase, title case, etc.)", + "example_inputs": {"text": "hello world", "format_type": "title"}, + "example_output": "Hello World", + }, + ] + + for example in examples: + print(f"\n๐Ÿ“ {example['name']}") + print(f" Description: {example['description']}") + print(f" Example Input: {example['example_inputs']}") + print(f" Example Output: {example['example_output']}") + + +if __name__ == "__main__": + demo_calculator() + show_api_usage() + show_more_examples() + + print("\n" + "=" * 50) + print("๐ŸŽ‰ Demo Complete!") + print("=" * 50) + print("\nTo actually use this tool:") + print("1. Start the RAG Studio backend") + print("2. Run: python calculator_tool_example.py --username your_username") + print("3. Or use the API directly with curl/requests") + print("4. Once submitted, use the tool in RAG Studio chat sessions") diff --git a/llm-service/examples/calculator_tool_example.py b/llm-service/examples/calculator_tool_example.py new file mode 100644 index 000000000..75cb02f60 --- /dev/null +++ b/llm-service/examples/calculator_tool_example.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +""" +Example: Simple Calculator Tool Submission + +This script demonstrates how to submit a user tool using the new user tools API. +The calculator tool performs basic arithmetic operations (add, subtract, multiply, divide). +""" + +import requests + +# Define the calculator tool +CALCULATOR_TOOL = { + "name": "simple_calculator", + "display_name": "Simple Calculator", + "description": "Performs basic arithmetic operations on two numbers", + "function_schema": { + "type": "object", + "properties": { + "first_number": { + "type": "number", + "description": "The first number in the operation", + }, + "second_number": { + "type": "number", + "description": "The second number in the operation", + }, + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform", + }, + }, + "required": ["first_number", "second_number", "operation"], + }, + "function_code": '''def simple_calculator(first_number: float, second_number: float, operation: str) -> float: + """ + Performs basic arithmetic operations on two numbers. + + Args: + first_number: The first number in the operation + second_number: The second number in the operation + operation: The arithmetic operation to perform (add, subtract, multiply, divide) + + Returns: + The result of the arithmetic operation + + Raises: + ValueError: If operation is not supported or division by zero + """ + if operation == "add": + return first_number + second_number + elif operation == "subtract": + return first_number - second_number + elif operation == "multiply": + return first_number * second_number + elif operation == "divide": + if second_number == 0: + raise ValueError("Cannot divide by zero") + return first_number / second_number + else: + raise ValueError(f"Unsupported operation: {operation}")''', +} + + +def submit_calculator_tool( + api_base_url: str = "http://localhost:8000", username: str = "example_user" +): + """ + Submit the calculator tool to the user tools API. + + Args: + api_base_url: Base URL of the RAG Studio API + username: Username to submit the tool under + """ + url = f"{api_base_url}/custom-tools" + headers = {"Content-Type": "application/json", "origin_remote_user": username} + + try: + response = requests.post(url, json=CALCULATOR_TOOL, headers=headers) + response.raise_for_status() + + print(f"โœ… Successfully submitted calculator tool!") + print(f"Response: {response.json()}") + + return True + + except requests.exceptions.RequestException as e: + print(f"โŒ Failed to submit calculator tool: {e}") + if hasattr(e, "response") and e.response is not None: + print(f"Response body: {e.response.text}") + return False + + +def test_calculator_tool( + api_base_url: str = "http://localhost:8000", username: str = "example_user" +): + """ + Test the calculator tool using the test endpoint. + + Args: + api_base_url: Base URL of the RAG Studio API + username: Username that owns the tool + """ + url = f"{api_base_url}/user-tools/simple_calculator/test" + headers = {"Content-Type": "application/json", "origin_remote_user": username} + + # Test cases + test_cases = [ + {"first_number": 10, "second_number": 5, "operation": "add"}, + {"first_number": 10, "second_number": 3, "operation": "subtract"}, + {"first_number": 7, "second_number": 6, "operation": "multiply"}, + {"first_number": 15, "second_number": 3, "operation": "divide"}, + ] + + print("\n๐Ÿงช Testing calculator tool...") + + for i, test_case in enumerate(test_cases, 1): + try: + response = requests.post( + url, json={"input_data": test_case}, headers=headers + ) + response.raise_for_status() + + result = response.json() + expected_results = {"add": 15, "subtract": 7, "multiply": 42, "divide": 5} + + operation = test_case["operation"] + expected = expected_results.get(operation) + actual = result.get("result") + + print( + f" Test {i}: {test_case['first_number']} {operation} {test_case['second_number']} = {actual}" + ) + + if expected == actual: + print(f" โœ… Correct! Expected {expected}") + else: + print(f" โŒ Wrong! Expected {expected}, got {actual}") + + except requests.exceptions.RequestException as e: + print(f" โŒ Test {i} failed: {e}") + if hasattr(e, "response") and e.response is not None: + print(f" Response: {e.response.text}") + + +def list_custom_tools( + api_base_url: str = "http://localhost:8000", username: str = "example_user" +): + """ + List all custom tools for the given username. + """ + url = f"{api_base_url}/custom-tools" + headers = {"origin_remote_user": username} + + try: + response = requests.get(url, headers=headers) + response.raise_for_status() + + tools = response.json() + print(f"\n๐Ÿ“‹ Custom tools for {username}:") + + if not tools: + print(" No tools found") + else: + for tool in tools: + print(f" - {tool['name']}: {tool['display_name']}") + print(f" Description: {tool['description']}") + + return tools + + except requests.exceptions.RequestException as e: + print(f"โŒ Failed to list user tools: {e}") + return [] + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Calculator tool example for RAG Studio user tools" + ) + parser.add_argument( + "--api-url", default="http://localhost:8000", help="API base URL" + ) + parser.add_argument("--username", default="example_user", help="Username to use") + parser.add_argument( + "--action", + choices=["submit", "test", "list", "all"], + default="all", + help="Action to perform", + ) + + args = parser.parse_args() + + print(f"๐Ÿ”ง Calculator Tool Example") + print(f"API URL: {args.api_url}") + print(f"Username: {args.username}") + print("-" * 50) + + if args.action in ["submit", "all"]: + submit_calculator_tool(args.api_url, args.username) + + if args.action in ["test", "all"]: + test_calculator_tool(args.api_url, args.username) + + if args.action in ["list", "all"]: + list_custom_tools(args.api_url, args.username) diff --git a/llm-service/examples/simple_calculator_tool.json b/llm-service/examples/simple_calculator_tool.json new file mode 100644 index 000000000..99477cb97 --- /dev/null +++ b/llm-service/examples/simple_calculator_tool.json @@ -0,0 +1,25 @@ +{ + "name": "simple_calculator", + "display_name": "Simple Calculator", + "description": "Performs basic arithmetic operations on two numbers", + "function_schema": { + "type": "object", + "properties": { + "first_number": { + "type": "number", + "description": "The first number in the operation" + }, + "second_number": { + "type": "number", + "description": "The second number in the operation" + }, + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform" + } + }, + "required": ["first_number", "second_number", "operation"] + }, + "function_code": "def simple_calculator(first_number: float, second_number: float, operation: str) -> float:\n \"\"\"\n Performs basic arithmetic operations on two numbers.\n \n Args:\n first_number: The first number in the operation\n second_number: The second number in the operation \n operation: The arithmetic operation to perform (add, subtract, multiply, divide)\n \n Returns:\n The result of the arithmetic operation\n \n Raises:\n ValueError: If operation is not supported or division by zero\n \"\"\"\n if operation == \"add\":\n return first_number + second_number\n elif operation == \"subtract\":\n return first_number - second_number\n elif operation == \"multiply\":\n return first_number * second_number\n elif operation == \"divide\":\n if second_number == 0:\n raise ValueError(\"Cannot divide by zero\")\n return first_number / second_number\n else:\n raise ValueError(f\"Unsupported operation: {operation}\")" +} diff --git a/tools/custom_tool_scripts/simple-calculator.py b/tools/custom_tool_scripts/simple-calculator.py new file mode 100644 index 000000000..06e8709ca --- /dev/null +++ b/tools/custom_tool_scripts/simple-calculator.py @@ -0,0 +1,27 @@ +def simple_calculator(first_number: float, second_number: float, operation: str) -> float: + """ + Performs basic arithmetic operations on two numbers. + + Args: + first_number: The first number in the operation + second_number: The second number in the operation + operation: The arithmetic operation to perform (add, subtract, multiply, divide) + + Returns: + The result of the arithmetic operation + + Raises: + ValueError: If operation is not supported or division by zero + """ + if operation == "add": + return first_number + second_number + elif operation == "subtract": + return first_number - second_number + elif operation == "multiply": + return first_number * second_number + elif operation == "divide": + if second_number == 0: + raise ValueError("Cannot divide by zero") + return first_number / second_number + else: + raise ValueError(f"Unsupported operation: {operation}") \ No newline at end of file diff --git a/tools/custom_tool_scripts/simple_calculator.py b/tools/custom_tool_scripts/simple_calculator.py new file mode 100644 index 000000000..06e8709ca --- /dev/null +++ b/tools/custom_tool_scripts/simple_calculator.py @@ -0,0 +1,27 @@ +def simple_calculator(first_number: float, second_number: float, operation: str) -> float: + """ + Performs basic arithmetic operations on two numbers. + + Args: + first_number: The first number in the operation + second_number: The second number in the operation + operation: The arithmetic operation to perform (add, subtract, multiply, divide) + + Returns: + The result of the arithmetic operation + + Raises: + ValueError: If operation is not supported or division by zero + """ + if operation == "add": + return first_number + second_number + elif operation == "subtract": + return first_number - second_number + elif operation == "multiply": + return first_number * second_number + elif operation == "divide": + if second_number == 0: + raise ValueError("Cannot divide by zero") + return first_number / second_number + else: + raise ValueError(f"Unsupported operation: {operation}") \ No newline at end of file diff --git a/tools/mcp.json b/tools/mcp.json index d5b9f834d..dcbafead7 100644 --- a/tools/mcp.json +++ b/tools/mcp.json @@ -7,9 +7,36 @@ "display_name": "Web Scraper", "description": "Extracts and retrieves content from web pages. This tool enables you to incorporate live web data into your queries, enhancing responses with up-to-date information from specified URLs." }, - "args": [ - "mcp-server-fetch" - ] + "args": ["mcp-server-fetch"] + } + ], + "custom_tools": [ + { + "name": "simple-calculator", + "metadata": { + "display_name": "Simple Calculator", + "description": "Performs basic arithmetic operations on two numbers" + }, + "function_schema": { + "type": "object", + "properties": { + "first_number": { + "type": "number", + "description": "The first number in the operation" + }, + "second_number": { + "type": "number", + "description": "The second number in the operation" + }, + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform" + } + }, + "required": ["first_number", "second_number", "operation"] + }, + "script_path": "../tools/custom_tool_scripts/simple-calculator.py" } ] -} \ No newline at end of file +} diff --git a/ui/src/api/toolsApi.ts b/ui/src/api/toolsApi.ts index 6177ad106..77644fec5 100644 --- a/ui/src/api/toolsApi.ts +++ b/ui/src/api/toolsApi.ts @@ -56,6 +56,14 @@ export interface Tool { description: string; display_name: string; }; + type?: "mcp" | "custom"; // Add tool type to distinguish MCP vs custom tools +} + +export interface CustomTool { + name: string; + display_name: string; + description: string; + script_path: string; } export interface AddToolFormValues { @@ -68,6 +76,17 @@ export interface AddToolFormValues { description: string; } +export interface CustomToolFormValues { + name: string; + display_name: string; + description: string; + script_file: File; +} + +export interface CustomToolTestRequest { + input_data: Record; +} + export const getTools = async (): Promise => { return getRequest(`${llmServicePath}/tools`); }; @@ -120,3 +139,165 @@ export const useDeleteToolMutation = ({ onError, }); }; + +// User Tools API +export const getCustomTools = async (): Promise => { + return getRequest(`${llmServicePath}/custom-tools`); +}; + +export const useCustomToolsQuery = () => { + return useQuery({ + queryKey: [QueryKeys.getCustomTools], + queryFn: getCustomTools, + }); +}; + +export const createCustomTool = async (toolData: { + name: string; + display_name: string; + description: string; + script_file: File; +}): Promise => { + const formData = new FormData(); + formData.append("name", toolData.name); + formData.append("display_name", toolData.display_name); + formData.append("description", toolData.description); + formData.append("script_file", toolData.script_file); + + const response = await fetch(`${llmServicePath}/custom-tools`, { + method: "POST", + body: formData, + }); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status.toString()}`); + } + + return response.json() as Promise; +}; + +export const useCreateCustomToolMutation = ({ + onSuccess, + onError, +}: UseMutationType) => { + const queryClient = useQueryClient(); + return useMutation({ + mutationFn: createCustomTool, + onSuccess: (tool) => { + void queryClient.invalidateQueries({ + queryKey: [QueryKeys.getCustomTools], + }); + if (onSuccess) { + onSuccess(tool); + } + }, + onError, + }); +}; + +export const updateCustomTool = async ( + toolName: string, + toolData: { + name: string; + display_name: string; + description: string; + script_file: File; + } +): Promise => { + const formData = new FormData(); + formData.append("name", toolData.name); + formData.append("display_name", toolData.display_name); + formData.append("description", toolData.description); + formData.append("script_file", toolData.script_file); + + const response = await fetch(`${llmServicePath}/custom-tools/${toolName}`, { + method: "PUT", + body: formData, + }); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status.toString()}`); + } + + return response.json() as Promise; +}; + +export const useUpdateCustomToolMutation = ({ + onSuccess, + onError, +}: UseMutationType) => { + const queryClient = useQueryClient(); + return useMutation({ + mutationFn: ({ + toolName, + toolData, + }: { + toolName: string; + toolData: { + name: string; + display_name: string; + description: string; + script_file: File; + }; + }) => updateCustomTool(toolName, toolData), + onSuccess: (tool) => { + void queryClient.invalidateQueries({ + queryKey: [QueryKeys.getCustomTools], + }); + if (onSuccess) { + onSuccess(tool); + } + }, + onError, + }); +}; + +export const deleteCustomTool = async (toolName: string): Promise => { + return deleteRequest(`${llmServicePath}/custom-tools/${toolName}`); +}; + +export const useDeleteCustomToolMutation = ({ + onSuccess, + onError, +}: UseMutationType) => { + const queryClient = useQueryClient(); + return useMutation({ + mutationFn: deleteCustomTool, + onSuccess: () => { + void queryClient.invalidateQueries({ + queryKey: [QueryKeys.getCustomTools], + }); + if (onSuccess) { + onSuccess(); + } + }, + onError, + }); +}; + +export const testCustomTool = async ( + toolName: string, + testData: CustomToolTestRequest +): Promise<{ success: boolean; result?: unknown; error?: string }> => { + return postRequest( + `${llmServicePath}/custom-tools/${toolName}/test`, + testData + ); +}; + +export const useTestCustomToolMutation = ({ + onSuccess, + onError, +}: UseMutationType<{ success: boolean; result?: unknown; error?: string }>) => { + return useMutation({ + mutationFn: ({ + toolName, + testData, + }: { + toolName: string; + testData: CustomToolTestRequest; + }) => testCustomTool(toolName, testData), + onSuccess, + onError, + }); +}; diff --git a/ui/src/api/utils.ts b/ui/src/api/utils.ts index 5e8e93453..06736e464 100644 --- a/ui/src/api/utils.ts +++ b/ui/src/api/utils.ts @@ -82,6 +82,10 @@ export enum MutationKeys { "restartApplication" = "restartApplication", "streamChatMutation" = "streamChatMutation", "setCdpToken" = "setCdpToken", + "createCustomTool" = "createCustomTool", + "updateCustomTool" = "updateCustomTool", + "deleteCustomTool" = "deleteCustomTool", + "testCustomTool" = "testCustomTool", } export enum QueryKeys { @@ -113,6 +117,8 @@ export enum QueryKeys { "getAmpConfig" = "getAmpConfig", "getTools" = "getTools", "getPollingAmpConfig" = "getPollingAmpConfig", + "getCustomTools" = "getCustomTools", + "getCustomTool" = "getCustomTool", } export const commonHeaders = { @@ -133,7 +139,7 @@ export interface CustomError { export class ApiError extends Error { constructor( message = "unknown", - public status: number, + public status: number ) { super(message); this.name = "CustomError"; @@ -143,7 +149,7 @@ export class ApiError extends Error { export const postRequest = async ( url: string, - body: Record, + body: Record ): Promise => { const res = await fetch(url, { method: "POST", @@ -174,6 +180,25 @@ export const getRequest = async (url: string): Promise => { return await ((await res.json()) as Promise); }; +export const putRequest = async ( + url: string, + body: Record +): Promise => { + const res = await fetch(url, { + method: "PUT", + body: JSON.stringify(body), + headers: { + ...commonHeaders, + "Content-Type": "application/json", + }, + }); + if (!res.ok) { + const detail = (await res.json()) as CustomError; + throw new ApiError(detail.message ?? detail.detail, res.status); + } + return await ((await res.json()) as Promise); +}; + export const deleteRequest = async (url: string) => { const res = await fetch(url, { method: "DELETE", diff --git a/ui/src/pages/RagChatTab/FooterComponents/ToolsManager.tsx b/ui/src/pages/RagChatTab/FooterComponents/ToolsManager.tsx index cdc5e021e..eb7fdfdb8 100644 --- a/ui/src/pages/RagChatTab/FooterComponents/ToolsManager.tsx +++ b/ui/src/pages/RagChatTab/FooterComponents/ToolsManager.tsx @@ -47,7 +47,7 @@ import { Tooltip, Typography, } from "antd"; -import { useToolsQuery } from "src/api/toolsApi.ts"; +import { useToolsQuery, useCustomToolsQuery } from "src/api/toolsApi.ts"; import { Dispatch, ReactNode, @@ -70,14 +70,27 @@ import { Link } from "@tanstack/react-router"; import { getAmpConfigQueryOptions } from "src/api/ampMetadataApi.ts"; const ToolsManagerContent = ({ activeSession }: { activeSession: Session }) => { - const { data, isLoading } = useToolsQuery(); + const { data: mcpTools, isLoading: mcpLoading } = useToolsQuery(); + const { data: customTools, isLoading: customLoading } = useCustomToolsQuery(); const { data: config } = useSuspenseQuery(getAmpConfigQueryOptions); - const toolsList = data?.map((tool) => ({ - name: tool.name, - displayName: tool.metadata.display_name, - description: tool.metadata.description, - })); + // Combine MCP tools and custom tools into a unified list + const toolsList = [ + ...(mcpTools?.map((tool) => ({ + name: tool.name, + displayName: tool.metadata.display_name, + description: tool.metadata.description, + type: "MCP" as const, + })) ?? []), + ...(customTools?.map((tool) => ({ + name: tool.name, + displayName: tool.display_name, + description: tool.description, + type: "User" as const, + })) ?? []), + ]; + + const isLoading = mcpLoading || customLoading; const queryClient = useQueryClient(); @@ -113,8 +126,8 @@ const ToolsManagerContent = ({ activeSession }: { activeSession: Session }) => { } else { handleUpdateSession( activeSession.queryConfiguration.selectedTools.filter( - (tool) => tool !== title, - ), + (tool) => tool !== title + ) ); } }; @@ -153,12 +166,30 @@ const ToolsManagerContent = ({ activeSession }: { activeSession: Session }) => { renderItem={(item) => ( + {item.displayName || item.name} + + {item.type} + + + } description={item.description} avatar={ { handleCheck(item.name, e.target.checked); diff --git a/ui/src/pages/Tools/AddNewToolModal.tsx b/ui/src/pages/Tools/AddNewToolModal.tsx index cdeea1d2f..37b6c741d 100644 --- a/ui/src/pages/Tools/AddNewToolModal.tsx +++ b/ui/src/pages/Tools/AddNewToolModal.tsx @@ -36,15 +36,31 @@ * DATA. */ -import { Button, Flex, Form, Input, Modal, Space, Typography } from "antd"; +import { + Button, + Flex, + Form, + FormInstance, + Input, + Modal, + Space, + Typography, + Upload, +} from "antd"; import { AddToolFormValues, + CustomToolFormValues, Tool, useAddToolMutation, + useCreateCustomToolMutation, } from "src/api/toolsApi.ts"; import messageQueue from "src/utils/messageQueue.ts"; import { useState } from "react"; -import { MinusCircleOutlined, PlusOutlined } from "@ant-design/icons"; +import { + InboxOutlined, + MinusCircleOutlined, + PlusOutlined, +} from "@ant-design/icons"; const CommandFormFields = () => { return ( @@ -125,6 +141,49 @@ const UrlFormFields = () => { ); }; +const UserToolFormFields = ({ + form, +}: { + form: FormInstance; +}) => { + return ( + <> + + false} // Prevent auto upload + accept=".py" + maxCount={1} + onChange={(info) => { + const file = info.fileList[0]?.originFileObj; + // Set the file in the form + if (file) { + form.setFieldsValue({ script_file: file }); + } + }} + style={{ padding: "20px" }} + > +

+ +

+

+ Click or drag Python file (.py) to this area to upload +

+

+ Support for a single Python script file. The file will be used to + create your custom tool. +

+
+
+ + ); +}; + export const AddNewToolModal = ({ isModalVisible, setIsModalVisible, @@ -132,52 +191,91 @@ export const AddNewToolModal = ({ isModalVisible: boolean; setIsModalVisible: (visible: boolean) => void; }) => { - const [form] = Form.useForm(); - const [toolType, setToolType] = useState<"command" | "url">("command"); + const [form] = Form.useForm(); + const [toolType, setToolType] = useState<"command" | "url" | "custom">( + "command", + ); + const addToolMutation = useAddToolMutation({ onSuccess: () => { - messageQueue.success("Tool added successfully"); + messageQueue.success("MCP tool added successfully"); + setIsModalVisible(false); + form.resetFields(); + }, + onError: (error) => { + messageQueue.error(`Failed to add MCP tool: ${error.message}`); + }, + }); + + const createCustomToolMutation = useCreateCustomToolMutation({ + onSuccess: () => { + messageQueue.success("Custom tool added successfully"); setIsModalVisible(false); form.resetFields(); }, onError: (error) => { - messageQueue.error(`Failed to add tool: ${error.message}`); + messageQueue.error(`Failed to add custom tool: ${error.message}`); }, }); const handleAddTool = () => { void form.validateFields().then((values) => { - const newTool: Tool = { - name: values.name, - metadata: { + if (toolType === "custom") { + // Handle custom tool creation + const customToolData = { + name: values.name, display_name: values.display_name, description: values.description, - }, - }; - - if (toolType === "command") { - newTool.command = values.command; - if (values.args) { - newTool.args = values.args.split(",").map((arg) => arg.trim()); - } - if (values.env?.length) { - newTool.env = values.env.reduce((accum, val) => { - return { ...accum, [val.key]: val.value }; - }, {}); - } + script_file: values.script_file, + }; + createCustomToolMutation.mutate(customToolData); } else { - if (values.url) { - newTool.url = values.url.split(",").map((url) => url.trim()); + // Handle MCP tool creation + const newTool: Tool = { + name: values.name, + metadata: { + display_name: values.display_name, + description: values.description, + }, + }; + + if (toolType === "command") { + newTool.command = values.command; + if (values.args) { + newTool.args = values.args.split(",").map((arg) => arg.trim()); + } + if (values.env?.length) { + newTool.env = values.env.reduce((accum, val) => { + return { ...accum, [val.key]: val.value }; + }, {}); + } + } else { + if (values.url) { + newTool.url = values.url.split(",").map((url) => url.trim()); + } } - } - addToolMutation.mutate(newTool); + addToolMutation.mutate(newTool); + } }); }; + const getModalTitle = () => { + switch (toolType) { + case "command": + return "Add MCP Command Tool"; + case "url": + return "Add MCP URL Tool"; + case "custom": + return "Add Custom Function Tool"; + default: + return "Add New Tool"; + } + }; + return ( { setIsModalVisible(false); @@ -197,9 +295,11 @@ export const AddNewToolModal = ({ key="submit" type="primary" onClick={handleAddTool} - loading={addToolMutation.isPending} + loading={ + addToolMutation.isPending || createCustomToolMutation.isPending + } > - Add + {toolType === "custom" ? "Create Custom Tool" : "Add MCP Tool"} , ]} > @@ -246,7 +346,7 @@ export const AddNewToolModal = ({ setToolType("command"); }} > - Command-based + MCP Command + - {toolType === "command" ? : } + {toolType === "command" && } + {toolType === "url" && } + {toolType === "custom" && } ); diff --git a/ui/src/pages/Tools/ToolsPage.tsx b/ui/src/pages/Tools/ToolsPage.tsx index e7a6b480e..50770cf2e 100644 --- a/ui/src/pages/Tools/ToolsPage.tsx +++ b/ui/src/pages/Tools/ToolsPage.tsx @@ -49,29 +49,97 @@ import { } from "antd"; import { DeleteOutlined, PlusOutlined } from "@ant-design/icons"; import { - Tool, useDeleteToolMutation, + useDeleteCustomToolMutation, useToolsQuery, + useCustomToolsQuery, } from "src/api/toolsApi.ts"; import messageQueue from "src/utils/messageQueue.ts"; import useModal from "src/utils/useModal.ts"; import { AddNewToolModal } from "pages/Tools/AddNewToolModal.tsx"; +interface UnifiedTool { + name: string; + display_name: string; + description: string; + type: "mcp" | "custom"; +} + const ToolsPage = () => { const confirmDeleteModal = useModal(); - const { data: tools = [], isLoading, error: toolsError } = useToolsQuery(); + const { + data: mcpTools = [], + isLoading: mcpLoading, + error: mcpError, + } = useToolsQuery(); + const { + data: customTools = [], + isLoading: customLoading, + error: customError, + } = useCustomToolsQuery(); const [isModalVisible, setIsModalVisible] = useState(false); + const [toolToDelete, setToolToDelete] = useState(null); + + // Transform data into unified format + const unifiedTools: UnifiedTool[] = [ + ...mcpTools.map( + (tool): UnifiedTool => ({ + name: tool.name, + display_name: tool.metadata.display_name, + description: tool.metadata.description, + type: "mcp", + }) + ), + ...customTools.map( + (tool): UnifiedTool => ({ + name: tool.name, + display_name: tool.display_name, + description: tool.description, + type: "custom", + }) + ), + ]; + + const isLoading = mcpLoading || customLoading; + const toolsError = mcpError ?? customError; const deleteToolMutation = useDeleteToolMutation({ onSuccess: () => { - messageQueue.success("Tool deleted successfully"); + messageQueue.success("MCP tool deleted successfully"); + confirmDeleteModal.setIsModalOpen(false); + setToolToDelete(null); + }, + onError: (error) => { + messageQueue.error(`Failed to delete MCP tool: ${error.message}`); + }, + }); + + const deleteCustomToolMutation = useDeleteCustomToolMutation({ + onSuccess: () => { + messageQueue.success("Custom tool deleted successfully"); confirmDeleteModal.setIsModalOpen(false); + setToolToDelete(null); }, onError: (error) => { - messageQueue.error(`Failed to delete tool: ${error.message}`); + messageQueue.error(`Failed to delete custom tool: ${error.message}`); }, }); + const handleDeleteTool = (tool: UnifiedTool) => { + setToolToDelete(tool); + confirmDeleteModal.setIsModalOpen(true); + }; + + const confirmDelete = () => { + if (!toolToDelete) return; + + if (toolToDelete.type === "mcp") { + deleteToolMutation.mutate(toolToDelete.name); + } else { + deleteCustomToolMutation.mutate(toolToDelete.name); + } + }; + const columns = [ { title: "Internal Name", @@ -80,43 +148,27 @@ const ToolsPage = () => { }, { title: "Display Name", - dataIndex: ["metadata", "display_name"], + dataIndex: "display_name", key: "display_name", }, { title: "Description", - dataIndex: ["metadata", "description"], + dataIndex: "description", key: "description", }, { title: "Actions", key: "actions", width: 80, - render: (_: unknown, tool: Tool) => ( - <> - { /> )} + + + { + confirmDeleteModal.handleCancel(); + setToolToDelete(null); + }} + > + Are you sure you want to delete the tool "{toolToDelete?.display_name} + "? This action cannot be undone. + );