Skip to content

Commit e567d5b

Browse files
authored
fix: invoke_with_tracing breaks rai used outside of developer space (#677)
1 parent fbf056f commit e567d5b

File tree

6 files changed

+326
-4
lines changed

6 files changed

+326
-4
lines changed

src/rai_core/rai/agents/langchain/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
create_react_runnable,
2020
create_state_based_runnable,
2121
)
22+
from .invocation_helpers import invoke_llm_with_tracing
2223
from .react_agent import ReActAgent
2324
from .state_based_agent import BaseStateBasedAgent, StateBasedConfig
2425

@@ -32,5 +33,6 @@
3233
"StateBasedConfig",
3334
"create_react_runnable",
3435
"create_state_based_runnable",
36+
"invoke_llm_with_tracing",
3537
"newMessageBehaviorType",
3638
]

src/rai_core/rai/agents/langchain/invocation_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def invoke_llm_with_tracing(
3535
This function automatically adds tracing callbacks (like Langfuse) to LLM calls
3636
within LangGraph nodes, solving the callback propagation issue.
3737
38+
Tracing is controlled by config.toml. If the file is missing, no tracing is applied.
39+
3840
Parameters
3941
----------
4042
llm : BaseChatModel

src/rai_core/rai/initialization/model_initialization.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,16 @@ def get_embeddings_model(
273273

274274

275275
def get_tracing_callbacks(
276-
override_use_langfuse: bool = False, override_use_langsmith: bool = False
276+
config_path: Optional[str] = None,
277277
) -> List[BaseCallbackHandler]:
278-
config = load_config()
278+
try:
279+
config = load_config(config_path)
280+
except Exception as e:
281+
logger.warning(f"Failed to load config for tracing: {e}, tracing disabled")
282+
return []
283+
279284
callbacks: List[BaseCallbackHandler] = []
280-
if config.tracing.langfuse.use_langfuse or override_use_langfuse:
285+
if config.tracing.langfuse.use_langfuse:
281286
from langfuse.callback import CallbackHandler # type: ignore
282287

283288
public_key = os.getenv("LANGFUSE_PUBLIC_KEY", None)
@@ -292,7 +297,7 @@ def get_tracing_callbacks(
292297
)
293298
callbacks.append(callback)
294299

295-
if config.tracing.langsmith.use_langsmith or override_use_langsmith:
300+
if config.tracing.langsmith.use_langsmith:
296301
os.environ["LANGCHAIN_TRACING_V2"] = "true"
297302
os.environ["LANGCHAIN_PROJECT"] = config.tracing.project
298303
api_key = os.getenv("LANGCHAIN_API_KEY", None)

tests/agents/langchain/test_langchain_agent.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from collections import deque
1617
from typing import List
18+
from unittest.mock import MagicMock, patch
1719

1820
import pytest
21+
from rai.agents.langchain import invoke_llm_with_tracing
1922
from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType
23+
from rai.initialization import get_tracing_callbacks
2024

2125

2226
@pytest.mark.parametrize(
@@ -39,3 +43,110 @@ def test_reduce_messages(
3943
output_ = LangChainAgent._apply_reduction_behavior(new_message_behavior, buffer)
4044
assert output == output_
4145
assert buffer == deque(out_buffer)
46+
47+
48+
class TestTracingConfiguration:
49+
"""Test tracing configuration integration with langchain agents."""
50+
51+
def test_tracing_with_missing_config_file(self):
52+
"""Test that tracing gracefully handles missing config.toml file in langchain context."""
53+
# This should not crash even without config.toml
54+
callbacks = get_tracing_callbacks()
55+
assert len(callbacks) == 0
56+
57+
def test_tracing_with_config_file_present(self, test_config_toml):
58+
"""Test that tracing works when config.toml is present in langchain context."""
59+
config_path, cleanup = test_config_toml(
60+
langfuse_enabled=True, langsmith_enabled=False
61+
)
62+
63+
try:
64+
# Mock environment variables to avoid actual API calls
65+
with patch.dict(
66+
os.environ,
67+
{
68+
"LANGFUSE_PUBLIC_KEY": "test_key",
69+
"LANGFUSE_SECRET_KEY": "test_secret",
70+
},
71+
):
72+
callbacks = get_tracing_callbacks(config_path=config_path)
73+
# Should return 1 callback for langfuse
74+
assert len(callbacks) == 1
75+
finally:
76+
cleanup()
77+
78+
79+
class TestInvokeLLMWithTracing:
80+
"""Test the invoke_llm_with_tracing function."""
81+
82+
def test_invoke_llm_without_tracing(self):
83+
"""Test that invoke_llm_with_tracing works when no tracing callbacks are available."""
84+
# Mock LLM
85+
mock_llm = MagicMock()
86+
mock_llm.invoke.return_value = "test response"
87+
88+
# Mock messages
89+
mock_messages = ["test message"]
90+
91+
# Mock get_tracing_callbacks to return empty list (no config.toml)
92+
with patch(
93+
"rai.agents.langchain.invocation_helpers.get_tracing_callbacks"
94+
) as mock_get_callbacks:
95+
mock_get_callbacks.return_value = []
96+
97+
result = invoke_llm_with_tracing(mock_llm, mock_messages)
98+
99+
mock_llm.invoke.assert_called_once_with(mock_messages, config=None)
100+
assert result == "test response"
101+
102+
def test_invoke_llm_with_tracing(self):
103+
"""Test that invoke_llm_with_tracing works when tracing callbacks are available."""
104+
# Mock LLM
105+
mock_llm = MagicMock()
106+
mock_llm.invoke.return_value = "test response"
107+
108+
# Mock messages
109+
mock_messages = ["test message"]
110+
111+
# Mock get_tracing_callbacks to return some callbacks
112+
with patch(
113+
"rai.agents.langchain.invocation_helpers.get_tracing_callbacks"
114+
) as mock_get_callbacks:
115+
mock_get_callbacks.return_value = ["tracing_callback"]
116+
117+
_ = invoke_llm_with_tracing(mock_llm, mock_messages)
118+
119+
# Verify that the LLM was called with enhanced config
120+
mock_llm.invoke.assert_called_once()
121+
call_args = mock_llm.invoke.call_args
122+
assert call_args[0][0] == mock_messages
123+
assert "callbacks" in call_args[1]["config"]
124+
assert "tracing_callback" in call_args[1]["config"]["callbacks"]
125+
126+
def test_invoke_llm_with_existing_config(self):
127+
"""Test that invoke_llm_with_tracing preserves existing config."""
128+
# Mock LLM
129+
mock_llm = MagicMock()
130+
mock_llm.invoke.return_value = "test response"
131+
132+
# Mock messages
133+
mock_messages = ["test message"]
134+
135+
# Mock existing config
136+
existing_config = {"callbacks": ["existing_callback"]}
137+
138+
# Mock get_tracing_callbacks to return some callbacks
139+
with patch(
140+
"rai.agents.langchain.invocation_helpers.get_tracing_callbacks"
141+
) as mock_get_callbacks:
142+
mock_get_callbacks.return_value = ["tracing_callback"]
143+
144+
_ = invoke_llm_with_tracing(mock_llm, mock_messages, existing_config)
145+
146+
# Verify that the LLM was called with enhanced config
147+
mock_llm.invoke.assert_called_once()
148+
call_args = mock_llm.invoke.call_args
149+
assert call_args[0][0] == mock_messages
150+
assert "callbacks" in call_args[1]["config"]
151+
assert "existing_callback" in call_args[1]["config"]["callbacks"]
152+
assert "tracing_callback" in call_args[1]["config"]["callbacks"]

tests/conftest.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,132 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
import os
16+
import tempfile
17+
18+
import pytest
19+
20+
21+
@pytest.fixture
22+
def test_config_toml():
23+
"""
24+
Fixture to create a temporary test config.toml file with tracing enabled.
25+
26+
Returns
27+
-------
28+
tuple
29+
(config_path, cleanup_function) - The path to the config file and a function to clean it up
30+
"""
31+
32+
def _create_config(langfuse_enabled=False, langsmith_enabled=False):
33+
# Create a temporary config.toml file
34+
f = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False)
35+
36+
# Base config sections (always required)
37+
config_content = """[vendor]
38+
simple_model = "openai"
39+
complex_model = "openai"
40+
embeddings_model = "text-embedding-ada-002"
41+
42+
[aws]
43+
simple_model = "anthropic.claude-instant-v1"
44+
complex_model = "anthropic.claude-v2"
45+
embeddings_model = "amazon.titan-embed-text-v1"
46+
region_name = "us-east-1"
47+
48+
[openai]
49+
simple_model = "gpt-3.5-turbo"
50+
complex_model = "gpt-4"
51+
embeddings_model = "text-embedding-ada-002"
52+
base_url = "https://api.openai.com/v1"
53+
54+
[ollama]
55+
simple_model = "llama2"
56+
complex_model = "llama2"
57+
embeddings_model = "llama2"
58+
base_url = "http://localhost:11434"
59+
60+
[tracing]
61+
project = "test-project"
62+
63+
[tracing.langfuse]
64+
use_langfuse = {langfuse_enabled}
65+
host = "http://localhost:3000"
66+
67+
[tracing.langsmith]
68+
use_langsmith = {langsmith_enabled}
69+
host = "https://api.smith.langchain.com"
70+
""".format(
71+
langfuse_enabled=str(langfuse_enabled).lower(),
72+
langsmith_enabled=str(langsmith_enabled).lower(),
73+
)
74+
75+
f.write(config_content)
76+
f.close()
77+
78+
def cleanup():
79+
try:
80+
f.close() # Ensure file is properly closed
81+
os.unlink(f.name)
82+
except (OSError, PermissionError):
83+
pass # File might already be deleted or have permission issues
84+
85+
return f.name, cleanup
86+
87+
return _create_config
88+
89+
90+
@pytest.fixture
91+
def test_config_no_tracing():
92+
"""
93+
Fixture to create a temporary test config.toml file with no tracing section.
94+
95+
Returns
96+
-------
97+
tuple
98+
(config_path, cleanup_function) - The path to the config file and a function to clean it up
99+
"""
100+
101+
def _create_config():
102+
# Create a temporary config.toml file
103+
f = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False)
104+
105+
# Base config sections (always required)
106+
config_content = """[vendor]
107+
simple_model = "openai"
108+
complex_model = "openai"
109+
embeddings_model = "text-embedding-ada-002"
110+
111+
[aws]
112+
simple_model = "anthropic.claude-instant-v1"
113+
complex_model = "anthropic.claude-v2"
114+
embeddings_model = "amazon.titan-embed-text-v1"
115+
region_name = "us-east-1"
116+
117+
[openai]
118+
simple_model = "gpt-3.5-turbo"
119+
complex_model = "gpt-4"
120+
embeddings_model = "text-embedding-ada-002"
121+
base_url = "https://api.openai.com/v1"
122+
123+
[ollama]
124+
simple_model = "llama2"
125+
complex_model = "llama2"
126+
embeddings_model = "llama2"
127+
base_url = "http://localhost:11434"
128+
"""
129+
130+
f.write(config_content)
131+
f.close()
132+
133+
def cleanup():
134+
try:
135+
f.close() # Ensure file is properly closed
136+
os.unlink(f.name)
137+
except (OSError, PermissionError):
138+
pass # File might already be deleted or have permission issues
139+
140+
return f.name, cleanup
141+
142+
return _create_config
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from unittest.mock import patch
17+
18+
from rai.initialization import get_tracing_callbacks
19+
20+
21+
class TestInitializationTracing:
22+
"""Test the initialization module's tracing functionality."""
23+
24+
def test_tracing_with_missing_config_file(self):
25+
"""Test that tracing gracefully handles missing config.toml file."""
26+
# This should not crash even without config.toml
27+
callbacks = get_tracing_callbacks()
28+
assert len(callbacks) == 0
29+
30+
def test_tracing_with_config_file_present_tracing_disabled(self, test_config_toml):
31+
"""Test that tracing works when config.toml is present but tracing is disabled."""
32+
config_path, cleanup = test_config_toml(
33+
langfuse_enabled=False, langsmith_enabled=False
34+
)
35+
36+
try:
37+
callbacks = get_tracing_callbacks(config_path=config_path)
38+
# Should return 0 callbacks since both langfuse and langsmith are disabled
39+
assert len(callbacks) == 0
40+
finally:
41+
cleanup()
42+
43+
def test_tracing_with_config_file_present_tracing_enabled(self, test_config_toml):
44+
"""Test that tracing works when config.toml is present and tracing is enabled."""
45+
config_path, cleanup = test_config_toml(
46+
langfuse_enabled=True, langsmith_enabled=False
47+
)
48+
49+
try:
50+
# Mock environment variables to avoid actual API calls
51+
with patch.dict(
52+
os.environ,
53+
{
54+
"LANGFUSE_PUBLIC_KEY": "test_key",
55+
"LANGFUSE_SECRET_KEY": "test_secret",
56+
},
57+
):
58+
callbacks = get_tracing_callbacks(config_path=config_path)
59+
# Should return 1 callback for langfuse
60+
assert len(callbacks) == 1
61+
finally:
62+
cleanup()
63+
64+
def test_tracing_with_valid_config_file_no_tracing(self, test_config_no_tracing):
65+
"""Test that tracing works when config.toml is valid but has no tracing sections."""
66+
config_path, cleanup = test_config_no_tracing()
67+
68+
try:
69+
# This should not crash, should return empty callbacks
70+
callbacks = get_tracing_callbacks(config_path=config_path)
71+
assert len(callbacks) == 0
72+
finally:
73+
cleanup()

0 commit comments

Comments
 (0)