4 changes: 4 additions & 0 deletions src/google/adk/models/lite_llm.py
@@ -483,13 +483,17 @@ def _model_response_to_generate_content_response(
"""

message = None
finish_reason = None
if response.get("choices", None):
message = response["choices"][0].get("message", None)
finish_reason = response["choices"][0].get("finish_reason", None)

if not message:
raise ValueError("No message in response")

llm_response = _message_to_generate_content_response(message)
if finish_reason:
llm_response.finish_reason = finish_reason
if response.get("usage", None):
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=response["usage"].get("prompt_tokens", 0),
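Note: the added lines read `finish_reason` off the first choice and copy it onto the `LlmResponse`. A minimal sketch of that extraction pattern, with a plain hypothetical dict standing in for a real LiteLLM `ModelResponse`:

# Sketch only: illustrates the finish_reason lookup used above.
# fake_response is a hypothetical stand-in, not a real LiteLLM object.
fake_response = {
    "choices": [
        {
            "message": {"role": "assistant", "content": "Hello"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 3},
}

finish_reason = None
if fake_response.get("choices", None):
  finish_reason = fake_response["choices"][0].get("finish_reason", None)

print(finish_reason)  # -> "stop"; lite_llm.py copies this onto llm_response.finish_reason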
6 changes: 5 additions & 1 deletion src/google/adk/telemetry/tracing.py
@@ -281,9 +281,13 @@ def trace_call_llm(
         llm_response.usage_metadata.candidates_token_count,
     )
   if llm_response.finish_reason:
+    if hasattr(llm_response.finish_reason, 'value'):
+      finish_reason_str = llm_response.finish_reason.value.lower()
+    else:
+      finish_reason_str = str(llm_response.finish_reason).lower()
     span.set_attribute(
         'gen_ai.response.finish_reasons',
-        [llm_response.finish_reason.value.lower()],
+        [finish_reason_str],
     )


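Note: `finish_reason` can now arrive either as an enum-like value with a `.value` attribute (the Gemini path) or as a plain string copied from LiteLLM (the lite_llm.py change above), so the tracing code normalizes both. A small self-contained sketch of that fallback, using a hypothetical FinishReason enum as the stand-in:

# Sketch only: enum-or-string normalization mirroring the hasattr() check above.
# FinishReason here is a hypothetical stand-in, not the google.genai type.
from enum import Enum


class FinishReason(Enum):
  STOP = "STOP"


for reason in (FinishReason.STOP, "length"):
  if hasattr(reason, "value"):
    finish_reason_str = reason.value.lower()
  else:
    finish_reason_str = str(reason).lower()
  print(finish_reason_str)  # -> "stop", then "length"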
140 changes: 140 additions & 0 deletions tests/unittests/models/test_litellm.py
@@ -1849,3 +1849,143 @@ def test_non_gemini_litellm_no_warning():
     # Test with non-Gemini model
     LiteLlm(model="openai/gpt-4o")
     assert len(w) == 0
+
+
+@pytest.mark.asyncio
+async def test_finish_reason_propagation_non_streaming(
+    mock_acompletion, lite_llm_instance
+):
+  """Test that finish_reason is properly propagated from LiteLLM response in non-streaming mode."""
+  mock_response_with_finish_reason = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="Test response",
+              ),
+              finish_reason="length",
+          )
+      ]
+  )
+  mock_acompletion.return_value = mock_response_with_finish_reason
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          )
+      ],
+  )
+
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.content.role == "model"
+    assert response.content.parts[0].text == "Test response"
+    assert response.finish_reason == "length"
+
+  mock_acompletion.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_finish_reason_propagation_stop(
+    mock_acompletion, lite_llm_instance
+):
+  """Test that finish_reason='stop' is properly propagated."""
+  mock_response_with_finish_reason = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="Complete response",
+              ),
+              finish_reason="stop",
+          )
+      ]
+  )
+  mock_acompletion.return_value = mock_response_with_finish_reason
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          )
+      ],
+  )
+
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.finish_reason == "stop"
+
+  mock_acompletion.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_finish_reason_propagation_tool_calls(
+    mock_acompletion, lite_llm_instance
+):
+  """Test that finish_reason='tool_calls' is properly propagated."""
+  mock_response_with_finish_reason = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="",
+                  tool_calls=[
+                      ChatCompletionMessageToolCall(
+                          type="function",
+                          id="test_id",
+                          function=Function(
+                              name="test_function",
+                              arguments='{"arg": "value"}',
+                          ),
+                      )
+                  ],
+              ),
+              finish_reason="tool_calls",
+          )
+      ]
+  )
+  mock_acompletion.return_value = mock_response_with_finish_reason
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          )
+      ],
+  )
+
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.finish_reason == "tool_calls"
+
+  mock_acompletion.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_finish_reason_content_filter(
+    mock_acompletion, lite_llm_instance
+):
+  """Test that finish_reason='content_filter' is properly propagated."""
+  mock_response_with_content_filter = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="",
+              ),
+              finish_reason="content_filter",
+          )
+      ]
+  )
+  mock_acompletion.return_value = mock_response_with_content_filter
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          )
+      ],
+  )
+
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.finish_reason == "content_filter"
+
+  mock_acompletion.assert_called_once()