diff --git a/src/posit/connect/_utils.py b/src/posit/connect/_utils.py index d9c1b083..b4d02b5c 100644 --- a/src/posit/connect/_utils.py +++ b/src/posit/connect/_utils.py @@ -1,9 +1,11 @@ from __future__ import annotations -import os +import warnings from typing_extensions import Any +from ..environment import is_local as env_is_local + def update_dict_values(obj: dict[str, Any], /, **kwargs: Any) -> None: """ @@ -33,10 +35,15 @@ def update_dict_values(obj: dict[str, Any], /, **kwargs: Any) -> None: def is_local() -> bool: - """Returns true if called from a piece of content running on a Connect server. + """ + Check if code is running locally. - The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`. - We can use this environment variable to determine if the content is running locally - or on a Connect server. + .. deprecated:: 0.9.0 + Use :func:`posit.environment.is_local` instead. """ - return os.getenv("RSTUDIO_PRODUCT") != "CONNECT" + warnings.warn( + "posit.connect._utils.is_local is deprecated. Use posit.environment.is_local instead.", + DeprecationWarning, + stacklevel=2, + ) + return env_is_local() diff --git a/src/posit/connect/client.py b/src/posit/connect/client.py index 1804caec..8665c9a0 100644 --- a/src/posit/connect/client.py +++ b/src/posit/connect/client.py @@ -208,6 +208,7 @@ def with_user_session_token(self, token: str) -> Client: -------- ```python from posit.connect import Client + client = Client().with_user_session_token("my-user-session-token") ``` @@ -218,13 +219,14 @@ def with_user_session_token(self, token: str) -> Client: client = Client() + @reactive.calc def visitor_client(): ## read the user session token and generate a new client - user_session_token = session.http_conn.headers.get( - "Posit-Connect-User-Session-Token" - ) + user_session_token = session.http_conn.headers.get("Posit-Connect-User-Session-Token") return client.with_user_session_token(user_session_token) + + @render.text def user_profile(): # fetch the viewer's profile information diff --git a/src/posit/connect/external/databricks.py b/src/posit/connect/external/databricks.py index 1f3b0895..93500a21 100644 --- a/src/posit/connect/external/databricks.py +++ b/src/posit/connect/external/databricks.py @@ -13,7 +13,7 @@ import requests from typing_extensions import Callable, Dict, Optional -from .._utils import is_local +from ...environment import is_local from ..client import Client from ..oauth import Credentials diff --git a/src/posit/connect/external/snowflake.py b/src/posit/connect/external/snowflake.py index c40c188d..82ea9561 100644 --- a/src/posit/connect/external/snowflake.py +++ b/src/posit/connect/external/snowflake.py @@ -9,7 +9,7 @@ from typing_extensions import Optional -from .._utils import is_local +from ...environment import is_local from ..client import Client diff --git a/src/posit/environment.py b/src/posit/environment.py new file mode 100644 index 00000000..42be7f1e --- /dev/null +++ b/src/posit/environment.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import os + + +def get_product() -> str | None: + """Returns the product name if called with a Posit product. + + The products will always set the environment variable `POSIT_PRODUCT=` + or `RSTUDIO_PRODUCT=`. + + RSTUDIO_PRODUCT is deprecated and acts as a fallback for backwards compatibility. + It is recommended to use POSIT_PRODUCT instead. + """ + return os.getenv("POSIT_PRODUCT") or os.getenv("RSTUDIO_PRODUCT") + + +def is_local() -> bool: + """Returns true if called while running locally.""" + return get_product() is None + + +def is_running_on_connect() -> bool: + """Returns true if called from a piece of content running on a Connect server.""" + return get_product() == "CONNECT" + + +def is_running_on_workbench() -> bool: + """Returns true if called from within a Workbench server.""" + return get_product() == "WORKBENCH" diff --git a/tests/posit/connect/external/test_databricks.py b/tests/posit/connect/external/test_databricks.py index 134911b1..c692f235 100644 --- a/tests/posit/connect/external/test_databricks.py +++ b/tests/posit/connect/external/test_databricks.py @@ -87,7 +87,7 @@ def test_new_bearer_authorization_header(self): def test_get_auth_type_local(self): assert _get_auth_type("local-auth") == "local-auth" - @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + @patch.dict("os.environ", {"POSIT_PRODUCT": "CONNECT"}) def test_get_auth_type_connect(self): assert _get_auth_type("local-auth") == POSIT_OAUTH_INTEGRATION_AUTH_TYPE @@ -176,7 +176,7 @@ def test_local_content_credentials_strategy(self): @patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"}) @responses.activate - @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + @patch.dict("os.environ", {"POSIT_PRODUCT": "CONNECT"}) def test_posit_content_credentials_strategy(self): register_mocks() @@ -191,7 +191,7 @@ def test_posit_content_credentials_strategy(self): assert cp() == {"Authorization": "Bearer content-access-token"} @responses.activate - @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + @patch.dict("os.environ", {"POSIT_PRODUCT": "CONNECT"}) def test_posit_credentials_strategy(self): register_mocks() diff --git a/tests/posit/connect/external/test_snowflake.py b/tests/posit/connect/external/test_snowflake.py index 3d2c5414..91e4e5dc 100644 --- a/tests/posit/connect/external/test_snowflake.py +++ b/tests/posit/connect/external/test_snowflake.py @@ -28,7 +28,7 @@ def register_mocks(): class TestPositAuthenticator: @responses.activate - @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + @patch.dict("os.environ", {"POSIT_PRODUCT": "CONNECT"}) def test_posit_authenticator(self): register_mocks() diff --git a/tests/posit/connect/test_client.py b/tests/posit/connect/test_client.py index be6fe9f9..e556f3be 100644 --- a/tests/posit/connect/test_client.py +++ b/tests/posit/connect/test_client.py @@ -85,7 +85,7 @@ def test_init( MockSession.assert_called_once() @responses.activate - @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + @patch.dict("os.environ", {"POSIT_PRODUCT": "CONNECT"}) def test_with_user_session_token(self): api_key = "12345" url = "https://connect.example.com" @@ -117,7 +117,7 @@ def test_with_user_session_token(self): assert visitor_client.cfg.api_key == "api-key" @responses.activate - @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + @patch.dict("os.environ", {"POSIT_PRODUCT": "CONNECT"}) def test_with_user_session_token_bad_exchange_response_body(self): api_key = "12345" url = "https://connect.example.com" @@ -143,7 +143,7 @@ def test_with_user_session_token_bad_exchange_response_body(self): client.with_user_session_token("cit") assert str(err.value) == "Unable to retrieve token." - @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + @patch.dict("os.environ", {"POSIT_PRODUCT": "CONNECT"}) def test_with_user_session_token_bad_token_deployed(self): api_key = "12345" url = "https://connect.example.com" diff --git a/tests/posit/test_environment.py b/tests/posit/test_environment.py new file mode 100644 index 00000000..2c970420 --- /dev/null +++ b/tests/posit/test_environment.py @@ -0,0 +1,54 @@ +# pyright: reportFunctionMemberAccess=false +from unittest.mock import patch + +import pytest + +from posit.environment import ( + get_product, + is_local, + is_running_on_connect, + is_running_on_workbench, +) + + +@pytest.mark.parametrize( + ("posit_product", "rstudio_product", "expected"), + [ + ("CONNECT", None, "CONNECT"), + (None, "WORKBENCH", "WORKBENCH"), + ("CONNECT", "WORKBENCH", "CONNECT"), + (None, None, None), + ], +) +def test_get_product(posit_product, rstudio_product, expected): + env = {} + if posit_product is not None: + env["POSIT_PRODUCT"] = posit_product + if rstudio_product is not None: + env["RSTUDIO_PRODUCT"] = rstudio_product + with patch.dict("os.environ", env, clear=True): + assert get_product() == expected + + +def test_is_local(): + with patch("posit.environment.get_product", return_value=None): + assert is_local() is True + + with patch("posit.environment.get_product", return_value="CONNECT"): + assert is_local() is False + + +def test_is_running_on_connect(): + with patch("posit.environment.get_product", return_value="CONNECT"): + assert is_running_on_connect() is True + + with patch("posit.environment.get_product", return_value="WORKBENCH"): + assert is_running_on_connect() is False + + +def test_is_running_on_workbench(): + with patch("posit.environment.get_product", return_value="WORKBENCH"): + assert is_running_on_workbench() is True + + with patch("posit.environment.get_product", return_value="CONNECT"): + assert is_running_on_workbench() is False