Skip to content

Commit 6074a1b

Browse files
authored
fix: refactor messages to use BaseMessage.text() (#683)
1 parent e567d5b commit 6074a1b

File tree

5 files changed

+52
-7
lines changed

5 files changed

+52
-7
lines changed

src/rai_bench/rai_bench/manipulation_o3de/benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
286286

287287
for msg in new_messages:
288288
if isinstance(msg, HumanMultimodalMessage):
289-
last_msg = msg.text
289+
last_msg = msg.text()
290290
elif isinstance(msg, BaseMessage):
291291
if isinstance(msg.content, list):
292292
if len(msg.content) == 1:

src/rai_core/rai/communication/hri_connector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,11 @@ def from_langchain(
103103
seq_no: int = 0,
104104
seq_end: bool = False,
105105
) -> "HRIMessage":
106+
text = message.text()
106107
if isinstance(message, RAIMultimodalMessage):
107-
text = message.text
108108
images = message.images
109109
audios = message.audios
110110
else:
111-
text = str(message.content)
112111
images = None
113112
audios = None
114113
if message.type not in ["ai", "human"]:

src/rai_core/rai/messages/multimodal.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,6 @@ def __init__(
6363
_content.extend(_image_content)
6464
self.content = _content
6565

66-
@property
67-
def text(self) -> str:
68-
return self.content[0]["text"]
69-
7066

7167
class HumanMultimodalMessage(HumanMessage, MultimodalMessage):
7268
def __repr_args__(self) -> Any:

tests/agents/langchain/test_langchain_agent.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@
1818
from unittest.mock import MagicMock, patch
1919

2020
import pytest
21+
from langchain_core.callbacks import BaseCallbackHandler
22+
from langchain_core.language_models.fake_chat_models import ParrotFakeChatModel
23+
from langchain_core.runnables import RunnableConfig
2124
from rai.agents.langchain import invoke_llm_with_tracing
2225
from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType
2326
from rai.initialization import get_tracing_callbacks
27+
from rai.messages import HumanMultimodalMessage
2428

2529

2630
@pytest.mark.parametrize(
@@ -150,3 +154,12 @@ def test_invoke_llm_with_existing_config(self):
150154
assert "callbacks" in call_args[1]["config"]
151155
assert "existing_callback" in call_args[1]["config"]["callbacks"]
152156
assert "tracing_callback" in call_args[1]["config"]["callbacks"]
157+
158+
def test_invoke_llm_with_callback_integration(self):
159+
"""Test that invoke_llm_with_tracing works with a callback handler."""
160+
llm = ParrotFakeChatModel()
161+
human_msg = HumanMultimodalMessage(content="human")
162+
response = llm.invoke(
163+
[human_msg], config=RunnableConfig(callbacks=[BaseCallbackHandler()])
164+
)
165+
assert response.content == [{"type": "text", "text": "human"}]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from rai.messages import HumanMultimodalMessage
17+
18+
19+
class TestMultimodalMessage:
20+
"""Test the MultimodalMessage class and expected behaviors."""
21+
22+
def test_human_multimodal_message_text_simple(self):
23+
"""Test text() method with simple text content."""
24+
msg = HumanMultimodalMessage(content="Hello world")
25+
assert msg.text() == "Hello world"
26+
assert isinstance(msg.text(), str)
27+
28+
def test_human_multimodal_message_text_with_images(self):
29+
"""Test text() method with text and images."""
30+
# Use a small valid base64 image (1x1 pixel PNG)
31+
valid_base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
32+
msg = HumanMultimodalMessage(
33+
content="Look at this image", images=[valid_base64_image]
34+
)
35+
assert msg.text() == "Look at this image"
36+
# Should only return text type blocks, not image content
37+
assert valid_base64_image not in msg.text()

0 commit comments

Comments
 (0)