diff --git a/stagehand/config.py b/stagehand/config.py index a577230..9557bf2 100644 --- a/stagehand/config.py +++ b/stagehand/config.py @@ -20,6 +20,7 @@ class StagehandConfig(BaseModel): browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions. model_name (Optional[str]): Name of the model to use. model_api_key (Optional[str]): Model API key. + model_client_options (Optional[dict[str, Any]]): Options for the model client. logger (Optional[Callable[[Any], None]]): Custom logging function. verbose (Optional[int]): Verbosity level for logs (1=minimal, 2=medium, 3=detailed). use_rich_logging (bool): Whether to use Rich for colorized logging. @@ -50,6 +51,9 @@ class StagehandConfig(BaseModel): model_api_key: Optional[str] = Field( None, alias="modelApiKey", description="Model API key" ) + model_client_options: Optional[dict[str, Any]] = Field( + None, alias="modelClientOptions", description="Configuration options for the language model client (i.e. apiKey, baseURL)", + ) verbose: Optional[int] = Field( 1, description="Verbosity level for logs: 0=minimal (ERROR), 1=medium (INFO), 2=detailed (DEBUG)", diff --git a/stagehand/llm/client.py b/stagehand/llm/client.py index 855b0fe..9d681d5 100644 --- a/stagehand/llm/client.py +++ b/stagehand/llm/client.py @@ -54,7 +54,7 @@ def __init__( setattr(litellm, key, value) self.logger.debug(f"Set global litellm.{key}", category="llm") # Handle common aliases or expected config names if necessary - elif key == "api_base": # Example: map api_base if needed + elif key == "api_base" or key == "baseURL": # Example: map api_base if needed litellm.api_base = value self.logger.debug( f"Set global litellm.api_base to {value}", category="llm" diff --git a/stagehand/main.py b/stagehand/main.py index 0de682e..0f98906 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -68,7 +68,11 @@ def __init__( # Handle non-config parameters self.api_url = self.config.api_url - self.model_api_key = self.config.model_api_key or os.getenv("MODEL_API_KEY") + + # Handle model-related settings + self.model_client_options = self.config.model_client_options or {} + self.model_api_key = self.config.model_api_key or self.model_client_options.get("apiKey") or os.getenv("MODEL_API_KEY") + self.model_name = self.config.model_name # Extract frequently used values from config for convenience @@ -89,11 +93,6 @@ def __init__( self.config.local_browser_launch_options or {} ) - # Handle model-related settings - self.model_client_options = {} - if self.model_api_key and "apiKey" not in self.model_client_options: - self.model_client_options["apiKey"] = self.model_api_key - # Handle browserbase session create params self.browserbase_session_create_params = make_serializable( self.config.browserbase_session_create_params diff --git a/tests/unit/llm/test_llm_integration.py b/tests/unit/llm/test_llm_integration.py index a01e7a7..00d09c4 100644 --- a/tests/unit/llm/test_llm_integration.py +++ b/tests/unit/llm/test_llm_integration.py @@ -40,6 +40,7 @@ def test_llm_client_with_custom_options(self): api_key="test-key", default_model="gpt-4o-mini", stagehand_logger=StagehandLogger(), + api_base="https://test-api-base.com", ) assert client.default_model == "gpt-4o-mini" diff --git a/tests/unit/test_client_api.py b/tests/unit/test_client_api.py index f6cb20b..e76e30d 100644 --- a/tests/unit/test_client_api.py +++ b/tests/unit/test_client_api.py @@ -19,7 +19,7 @@ async def mock_client(self): browserbase_session_id="test-session-123", api_key="test-api-key", project_id="test-project-id", - model_api_key="test-model-api-key", + model_client_options={"apiKey": "test-model-api-key"} ) return client diff --git a/tests/unit/test_client_initialization.py b/tests/unit/test_client_initialization.py index cd748ac..ff22039 100644 --- a/tests/unit/test_client_initialization.py +++ b/tests/unit/test_client_initialization.py @@ -23,7 +23,7 @@ def test_init_with_direct_params(self): browserbase_session_id="test-session", api_key="test-api-key", project_id="test-project-id", - model_api_key="test-model-api-key", + model_client_options={"apiKey": "test-model-api-key"}, verbose=2, ) @@ -203,3 +203,32 @@ async def mock_create_session(): # Call _create_session and expect error with pytest.raises(RuntimeError, match="Invalid response format"): await client._create_session() + + @mock.patch.dict(os.environ, {"MODEL_API_KEY": "test-model-api-key"}, clear=True) + def test_init_with_model_api_key_in_env(self): + config = StagehandConfig(env="LOCAL") + client = Stagehand(config=config) + assert client.model_api_key == "test-model-api-key" + + def test_init_with_custom_llm(self): + config = StagehandConfig( + env="LOCAL", + model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"} + ) + client = Stagehand(config=config) + assert client.model_api_key == "custom-llm-key" + assert client.model_client_options["apiKey"] == "custom-llm-key" + assert client.model_client_options["baseURL"] == "https://custom-llm.com" + + def test_init_with_custom_llm_override(self): + config = StagehandConfig( + env="LOCAL", + model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"} + ) + client = Stagehand( + config=config, + model_client_options={"apiKey": "override-llm-key", "baseURL": "https://override-llm.com"} + ) + assert client.model_api_key == "override-llm-key" + assert client.model_client_options["apiKey"] == "override-llm-key" + assert client.model_client_options["baseURL"] == "https://override-llm.com" \ No newline at end of file