Fix: Missing injected prompt in conversation context, prompt injected using before_model_callback method #3148
@@ -16,6 +16,39 @@ async def before_model_callback(callback_context, llm_request):

```python
  return None


INJECTIONS_STATE_KEY = "__persisted_prompt_injections"
_index = 0


def before_model_callback_persist_injections(callback_context, llm_request):
  """Sample before_model_callback that persists prompt injections in state.

  This function demonstrates how user code can store small text injections
  in the callback context state so future requests will include them.
  """
  global _index
  if not getattr(llm_request, 'contents', None):
    llm_request.contents = []
  injections = callback_context.state.get(INJECTIONS_STATE_KEY, [])
  for inj in injections:
    found = False
    for c in llm_request.contents:
      for part in getattr(c, 'parts', []):
        if getattr(part, 'text', None) == inj:
          found = True
          break
      if found:
        break
    if not found:
      llm_request.contents.insert(
          0, types.Content(role='user', parts=[types.Part(text=inj)])
      )
```
Comment on lines +32 to +42

The nested loops used to check whether an injection already exists are inefficient. This can be optimized by first building a set of all the existing text in `llm_request.contents` and then checking each injection against that set.

Suggested change:
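The body of the suggestion was not captured on this page; a minimal sketch of the set-based check the reviewer describes, reusing the `llm_request`, `injections`, and `types` names from the sample above, might look like:

```python
# Collect the text of every existing part once, so each injection
# can be checked with an O(1) set lookup instead of nested loops.
existing_texts = {
    part.text
    for c in llm_request.contents
    for part in (getattr(c, 'parts', None) or [])
    if getattr(part, 'text', None)
}
for inj in injections:
  if inj not in existing_texts:
    llm_request.contents.insert(
        0, types.Content(role='user', parts=[types.Part(text=inj)])
    )
```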
```python
  new_inj = f"Ignore this user input: Index={_index}"
  llm_request.contents.append(
      types.Content(role='user', parts=[types.Part(text=new_inj)])
  )
  injections.append(new_inj)
  callback_context.state[INJECTIONS_STATE_KEY] = injections
  _index += 1
  return None


async def after_model_callback(callback_context, llm_response):
  print('@after_model_callback')
  return None
```
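For context, a sample callback like this is typically registered on an agent via its `before_model_callback` parameter. A minimal sketch of that wiring, assuming the google-adk `Agent` class; the model and agent names here are illustrative, not part of this diff:

```python
from google.adk.agents import Agent

# Hypothetical wiring: register the sample callback so it runs
# before every model request this agent makes.
agent = Agent(
    model='gemini-2.0-flash',  # illustrative model name
    name='injection_demo_agent',
    instruction='Answer the user concisely.',
    before_model_callback=before_model_callback_persist_injections,
)
```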
@@ -799,6 +799,59 @@ async def _handle_before_model_callback(

```python
        invocation_context, event_actions=model_response_event.actions
    )

    state = getattr(callback_context, 'state', None)
    if state is None:
      injections = None
    else:
      injections = state.get("__persisted_prompt_injections")
    if not injections:
      try:
        session_state = getattr(invocation_context, 'session', None)
        if session_state and getattr(session_state, 'state', None) is not None:
          persisted = session_state.state.get("__persisted_prompt_injections")
          if persisted:
            injections = persisted
      except Exception:
        injections = injections

    if injections:
      if not isinstance(injections, list):
        injections = [injections]
      for inj in injections:
        if not inj:
          continue
        already = False
        inj_text = None
        if isinstance(inj, str):
          inj_text = inj
        else:
          parts = getattr(inj, 'parts', None)
          if parts and len(parts) and hasattr(parts[0], 'text'):
            inj_text = parts[0].text
          else:
            inj_text = None

        if inj_text:
          for c in list(llm_request.contents or []):
            if not c.parts:
              continue
            for part in c.parts:
              if part and getattr(part, 'text', None) == inj_text:
                already = True
                break
            if already:
              break
        if already:
          continue

        llm_request.contents = llm_request.contents or []
        if isinstance(inj, str):
          llm_request.contents.insert(
              0, types.Content(role="user", parts=[types.Part(text=inj)])
          )
        else:
          llm_request.contents.insert(0, inj)
```
Comment on lines +802 to +853

This new block for handling prompt injections has several areas for improvement regarding safety, efficiency, and maintainability: the broad `except Exception` silently swallows errors, the nested loops re-scan every content part for each injection, and the branching between string and `Content` injections is more convoluted than it needs to be.
Here is a refactored version of the block that addresses these points:

```python
injections = callback_context.state.get("__persisted_prompt_injections")
if not injections:
  session = getattr(invocation_context, 'session', None)
  if session:
    injections = session.state.get("__persisted_prompt_injections")
if injections:
  if not isinstance(injections, list):
    injections = [injections]
  llm_request.contents = llm_request.contents or []
  existing_texts = {
      part.text
      for c in llm_request.contents
      for part in (c.parts or [])
      if getattr(part, 'text', None)
  }
  for inj in injections:
    if not inj:
      continue
    inj_text = None
    if isinstance(inj, str):
      inj_text = inj
    elif getattr(inj, 'parts', None) and inj.parts and hasattr(inj.parts[0], 'text'):
      inj_text = inj.parts[0].text
    if inj_text and inj_text in existing_texts:
      continue
    if isinstance(inj, str):
      llm_request.contents.insert(
          0, types.Content(role="user", parts=[types.Part(text=inj)])
      )
    else:
      llm_request.contents.insert(0, inj)
```
The diff then continues into the existing plugin dispatch code:

```python
    # First run callbacks from the plugins.
    callback_response = (
        await invocation_context.plugin_manager.run_before_model_callback(
```
Using a global variable `_index` is not thread-safe and is not a recommended pattern, especially for a sample file that users might copy. In a real-world scenario with concurrent requests, this could lead to race conditions. A better approach is to store the index in the `callback_context.state`, which is scoped to the user's session. For example:
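The example itself was not captured on this page; a minimal sketch of what the reviewer describes, keeping the counter in session state rather than a module-level global (the `__injection_index` state key is a hypothetical name for illustration):

```python
INDEX_STATE_KEY = "__injection_index"  # hypothetical state key


def before_model_callback_persist_injections(callback_context, llm_request):
  """Same sample as above, but with the counter kept in session state."""
  # Read the per-session counter from state instead of a global variable;
  # state is scoped to the session, so concurrent sessions do not race
  # on a shared module-level `_index`.
  index = callback_context.state.get(INDEX_STATE_KEY, 0)
  new_inj = f"Ignore this user input: Index={index}"
  llm_request.contents = llm_request.contents or []
  llm_request.contents.append(
      types.Content(role='user', parts=[types.Part(text=new_inj)])
  )
  callback_context.state[INDEX_STATE_KEY] = index + 1
  return None
```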