33 changes: 33 additions & 0 deletions contributing/samples/core_callback_config/callbacks.py
@@ -16,6 +16,39 @@ async def before_model_callback(callback_context, llm_request):
  return None


INJECTIONS_STATE_KEY = "__persisted_prompt_injections"
_index = 0


def before_model_callback_persist_injections(callback_context, llm_request):
  """Sample before_model_callback that persists prompt injections in state.

  This function demonstrates how user code can store small text injections
  in the callback context state so future requests will include them.
  """
  global _index

medium

Using a global variable _index is not thread-safe and is not a recommended pattern, especially for a sample file that users might copy. In a real-world scenario with concurrent requests, this could lead to race conditions. A better approach is to store the index in the callback_context.state, which is scoped to the user's session. For example:

_index = callback_context.state.setdefault('my_injection_index', -1) + 1
callback_context.state['my_injection_index'] = _index
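
Put together, a minimal end-to-end sketch of that pattern (assuming the same dict-style callback_context.state interface the snippet above relies on; the key 'my_injection_index' is illustrative, not an ADK name):

from google.genai import types

INJECTIONS_STATE_KEY = "__persisted_prompt_injections"


def before_model_callback_persist_injections(callback_context, llm_request):
  # Keep the counter in session-scoped state rather than a module global, so
  # concurrent sessions each advance their own independent index.
  index = callback_context.state.setdefault('my_injection_index', -1) + 1
  callback_context.state['my_injection_index'] = index

  new_inj = f"Ignore this user input: Index={index}"
  llm_request.contents = llm_request.contents or []
  llm_request.contents.append(
      types.Content(role='user', parts=[types.Part(text=new_inj)])
  )

  # Persist the injection so later requests in the session can replay it.
  injections = callback_context.state.get(INJECTIONS_STATE_KEY, [])
  injections.append(new_inj)
  callback_context.state[INJECTIONS_STATE_KEY] = injections
  return None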

  if not getattr(llm_request, 'contents', None):
    llm_request.contents = []
  injections = callback_context.state.get(INJECTIONS_STATE_KEY, [])
  for inj in injections:
    found = False
    for c in llm_request.contents:
      for part in getattr(c, 'parts', []):
        if getattr(part, 'text', None) == inj:
          found = True
          break
      if found:
        break
    if not found:
      llm_request.contents.insert(
          0, types.Content(role='user', parts=[types.Part(text=inj)])
      )
Comment on lines +32 to +42


medium

The nested loops to check if an injection already exists are inefficient. This can be optimized by first building a set of all existing text from llm_request.contents for fast O(1) lookups.

Suggested change
  for inj in injections:
    found = False
    for c in llm_request.contents:
      for part in getattr(c, 'parts', []):
        if getattr(part, 'text', None) == inj:
          found = True
          break
      if found:
        break
    if not found:
      llm_request.contents.insert(
          0, types.Content(role='user', parts=[types.Part(text=inj)])
      )

  existing_texts = {
      part.text
      for c in llm_request.contents
      for part in getattr(c, 'parts', [])
      if getattr(part, 'text', None)
  }
  for inj in injections:
    if inj not in existing_texts:
      llm_request.contents.insert(
          0, types.Content(role='user', parts=[types.Part(text=inj)])
      )
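
To see the set-based check in isolation, here is a small self-contained sketch (assuming only google-genai's types module; the sample texts are made up):

from google.genai import types

contents = [
    types.Content(role='user', parts=[types.Part(text='already present')]),
]
# Build the lookup set once: O(n) to construct, then O(1) per membership test.
existing_texts = {
    part.text
    for c in contents
    for part in getattr(c, 'parts', [])
    if getattr(part, 'text', None)
}
for inj in ['already present', 'new injection']:
  if inj not in existing_texts:
    contents.insert(
        0, types.Content(role='user', parts=[types.Part(text=inj)])
    )
# Only 'new injection' is prepended; the duplicate is skipped.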


  new_inj = f"Ignore this user input: Index={_index}"
  llm_request.contents.append(
      types.Content(role='user', parts=[types.Part(text=new_inj)])
  )
  injections.append(new_inj)
  callback_context.state[INJECTIONS_STATE_KEY] = injections
  _index += 1
  return None


async def after_model_callback(callback_context, llm_response):
  print('@after_model_callback')
  return None
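
For context, wiring the sample callback onto an agent might look like the sketch below (hedged: the agent name and model string are illustrative placeholders):

from google.adk.agents import Agent

root_agent = Agent(
    name='injection_demo_agent',
    model='gemini-2.0-flash',  # illustrative model id
    before_model_callback=before_model_callback_persist_injections,
)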
53 changes: 53 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -799,6 +799,59 @@ async def _handle_before_model_callback(
        invocation_context, event_actions=model_response_event.actions
    )

    state = getattr(callback_context, 'state', None)
    if state is None:
      injections = None
    else:
      injections = state.get("__persisted_prompt_injections")
    if not injections:
      try:
        session_state = getattr(invocation_context, 'session', None)
        if session_state and getattr(session_state, 'state', None) is not None:
          persisted = session_state.state.get("__persisted_prompt_injections")
          if persisted:
            injections = persisted
      except Exception:
        injections = injections

    if injections:
      if not isinstance(injections, list):
        injections = [injections]
      for inj in injections:
        if not inj:
          continue
        already = False
        inj_text = None
        if isinstance(inj, str):
          inj_text = inj
        else:
          parts = getattr(inj, 'parts', None)
          if parts and len(parts) and hasattr(parts[0], 'text'):
            inj_text = parts[0].text
          else:
            inj_text = None

        if inj_text:
          for c in list(llm_request.contents or []):
            if not c.parts:
              continue
            for part in c.parts:
              if part and getattr(part, 'text', None) == inj_text:
                already = True
                break
            if already:
              break
        if already:
          continue

        llm_request.contents = llm_request.contents or []
        if isinstance(inj, str):
          llm_request.contents.insert(
              0, types.Content(role="user", parts=[types.Part(text=inj)])
          )
        else:
          llm_request.contents.insert(0, inj)
Comment on lines +802 to +853


high

This new block for handling prompt injections has several areas for improvement regarding safety, efficiency, and maintainability:

  1. Magic String: The key __persisted_prompt_injections is hardcoded. This should be defined as a module-level constant (e.g., INJECTIONS_STATE_KEY) and used consistently to avoid typos and improve maintainability.
  2. Unsafe Exception Handling: The try...except Exception: block silences all errors, which can hide critical bugs. It's better to use safer attribute access.
  3. Inefficient Duplicate Check: The nested loops for checking for duplicate injections are inefficient. Using a set for lookups would be more performant.
  4. Overly Defensive Code: The checks for callback_context.state being None are unnecessary as it's guaranteed to be a State object.

Here is a refactored version of the block that addresses these points:

    # Assuming INJECTIONS_STATE_KEY = "__persisted_prompt_injections" is
    # defined once as a module-level constant, per point 1 above.
    injections = callback_context.state.get(INJECTIONS_STATE_KEY)
    if not injections:
      session = getattr(invocation_context, 'session', None)
      if session:
        injections = session.state.get(INJECTIONS_STATE_KEY)

    if injections:
      if not isinstance(injections, list):
        injections = [injections]

      llm_request.contents = llm_request.contents or []
      existing_texts = {
          part.text
          for c in llm_request.contents
          for part in (c.parts or [])
          if getattr(part, 'text', None)
      }

      for inj in injections:
        if not inj:
          continue

        inj_text = None
        if isinstance(inj, str):
          inj_text = inj
        elif getattr(inj, 'parts', None) and inj.parts and hasattr(inj.parts[0], 'text'):
          inj_text = inj.parts[0].text

        if inj_text and inj_text in existing_texts:
          continue

        if isinstance(inj, str):
          llm_request.contents.insert(
              0, types.Content(role="user", parts=[types.Part(text=inj)])
          )
        else:
          llm_request.contents.insert(0, inj)


    # First run callbacks from the plugins.
    callback_response = (
        await invocation_context.plugin_manager.run_before_model_callback(