Fix: Missing injected prompt in conversation context, prompt injected using before_model_callback method #3148
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Summary of Changes

Hello @coder-jatin-s, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves an issue where prompt injections made through the `before_model_callback` were not persisted in the conversation context on subsequent turns.
Code Review

This pull request introduces support for persistent prompt injections using a `before_model_callback`. The changes look good conceptually, but there are a few areas for improvement in the implementation. My review focuses on enhancing code safety, efficiency, and maintainability, particularly in `base_llm_flow.py`, where I've suggested a refactoring to address a risky `try-except` block, use of magic strings, and inefficient looping. I've also provided feedback on the sample callback to promote better practices.
The block under review in `base_llm_flow.py`:

```python
state = getattr(callback_context, 'state', None)
if state is None:
  injections = None
else:
  injections = state.get("__persisted_prompt_injections")
if not injections:
  try:
    session_state = getattr(invocation_context, 'session', None)
    if session_state and getattr(session_state, 'state', None) is not None:
      persisted = session_state.state.get("__persisted_prompt_injections")
      if persisted:
        injections = persisted
  except Exception:
    injections = injections

if injections:
  if not isinstance(injections, list):
    injections = [injections]
  for inj in injections:
    if not inj:
      continue
    already = False
    inj_text = None
    if isinstance(inj, str):
      inj_text = inj
    else:
      parts = getattr(inj, 'parts', None)
      if parts and len(parts) and hasattr(parts[0], 'text'):
        inj_text = parts[0].text
      else:
        inj_text = None

    if inj_text:
      for c in list(llm_request.contents or []):
        if not c.parts:
          continue
        for part in c.parts:
          if part and getattr(part, 'text', None) == inj_text:
            already = True
            break
        if already:
          break
      if already:
        continue

    llm_request.contents = llm_request.contents or []
    if isinstance(inj, str):
      llm_request.contents.insert(
          0, types.Content(role="user", parts=[types.Part(text=inj)])
      )
    else:
      llm_request.contents.insert(0, inj)
```
This new block for handling prompt injections has several areas for improvement regarding safety, efficiency, and maintainability:

- **Magic string:** The key `__persisted_prompt_injections` is hardcoded. It should be defined as a module-level constant (e.g., `INJECTIONS_STATE_KEY`) and used consistently to avoid typos and improve maintainability.
- **Unsafe exception handling:** The `try...except Exception:` block silences all errors, which can hide critical bugs. It's better to use safer attribute access.
- **Inefficient duplicate check:** The nested loops that check for duplicate injections are inefficient. Using a `set` for lookups would be more performant.
- **Overly defensive code:** The checks for `callback_context.state` being `None` are unnecessary, as it is guaranteed to be a `State` object.

Here is a refactored version of the block that addresses these points:
```python
injections = callback_context.state.get("__persisted_prompt_injections")
if not injections:
  session = getattr(invocation_context, 'session', None)
  if session:
    injections = session.state.get("__persisted_prompt_injections")
if injections:
  if not isinstance(injections, list):
    injections = [injections]
  llm_request.contents = llm_request.contents or []
  existing_texts = {
      part.text
      for c in llm_request.contents
      for part in (c.parts or [])
      if getattr(part, 'text', None)
  }
  for inj in injections:
    if not inj:
      continue
    inj_text = None
    if isinstance(inj, str):
      inj_text = inj
    elif getattr(inj, 'parts', None) and inj.parts and hasattr(inj.parts[0], 'text'):
      inj_text = inj.parts[0].text
    if inj_text and inj_text in existing_texts:
      continue
    if isinstance(inj, str):
      llm_request.contents.insert(
          0, types.Content(role="user", parts=[types.Part(text=inj)])
      )
    else:
      llm_request.contents.insert(0, inj)
```
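As a follow-up to the first bullet point, a minimal sketch of the constant-based lookup is shown below. The constant name `INJECTIONS_STATE_KEY` comes from this review and the PR description; its exact placement in `base_llm_flow.py` is an assumption, and the fragment is meant to slot into the refactored block above rather than run on its own.

```python
# Module-level constant for the persisted-injections state key, so the raw
# string is defined in exactly one place (name taken from this review).
INJECTIONS_STATE_KEY = "__persisted_prompt_injections"

# Lookups then reference the constant instead of repeating the magic string.
injections = callback_context.state.get(INJECTIONS_STATE_KEY)
if not injections:
  session = getattr(invocation_context, 'session', None)
  if session:
    injections = session.state.get(INJECTIONS_STATE_KEY)
```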
From the sample before-model callback:

```python
  This function demonstrates how user code can store small text injections
  in the callback context state so future requests will include them.
  """
  global _index
```
Using a global variable `_index` is not thread-safe and is not a recommended pattern, especially for a sample file that users might copy. In a real-world scenario with concurrent requests, this could lead to race conditions. A better approach is to store the index in `callback_context.state`, which is scoped to the user's session. For example:

```python
_index = callback_context.state.setdefault('my_injection_index', -1) + 1
callback_context.state['my_injection_index'] = _index
```
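For illustration, here is a minimal sketch of how the sample callback could keep its rotation index in session state instead of a module-level global. The callback name `persist_next_injection`, the `_SAMPLE_INJECTIONS` list, and the `my_injection_index` key are hypothetical, and the `CallbackContext`/`LlmRequest` import paths follow the ADK layout as I recall it and may differ between versions.

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmRequest, LlmResponse
from google.genai import types

# Hypothetical sample injections; a real callback would compute or load these.
_SAMPLE_INJECTIONS = [
    "Always answer in a formal tone.",
    "The user's preferred language is English.",
]


def persist_next_injection(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
  """Picks the next sample injection, keeping its index in session state."""
  # Session-scoped counter instead of a module-level global, so concurrent
  # sessions cannot race on shared state ('my_injection_index' is hypothetical).
  index = callback_context.state.get('my_injection_index', -1) + 1
  callback_context.state['my_injection_index'] = index
  injection = _SAMPLE_INJECTIONS[index % len(_SAMPLE_INJECTIONS)]

  # Record the injection so later turns can merge it back into the request;
  # the key mirrors the one used elsewhere in this PR.
  persisted = list(callback_context.state.get('__persisted_prompt_injections', []))
  if injection not in persisted:
    persisted.append(injection)
  callback_context.state['__persisted_prompt_injections'] = persisted

  # Also prepend it to the current request so this turn sees it immediately.
  llm_request.contents = llm_request.contents or []
  llm_request.contents.insert(
      0, types.Content(role='user', parts=[types.Part(text=injection)])
  )
  return None  # None means: continue with the (modified) request.
```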
Another block from the PR under review:

```python
for inj in injections:
  found = False
  for c in llm_request.contents:
    for part in getattr(c, 'parts', []):
      if getattr(part, 'text', None) == inj:
        found = True
        break
    if found:
      break
  if not found:
    llm_request.contents.insert(0, types.Content(role='user', parts=[types.Part(text=inj)]))
```
The nested loops that check whether an injection already exists are inefficient. This can be optimized by first building a set of all existing text from `llm_request.contents` for fast O(1) lookups.
Suggested replacement (the block above rewritten around a set):

```python
existing_texts = {
    part.text
    for c in llm_request.contents
    for part in getattr(c, 'parts', [])
    if getattr(part, 'text', None)
}
for inj in injections:
  if inj not in existing_texts:
    llm_request.contents.insert(0, types.Content(role='user', parts=[types.Part(text=inj)]))
```
Hi @coder-jatin-s, thanks for your contribution! This looks promising. Here are the next steps:

1. Sign the CLA: this is the first requirement before we can accept the code.
2. Add unit tests: after the CLA is signed, please add tests to cover the new functionality.

Let us know if you have any questions!
Fix #3138

I added persistent prompt injection support and tests: introduced `INJECTIONS_STATE_KEY`, merged persisted prompt injections from callback/session state, added a sample before-model callback that persists injections, and added a unit test to verify cross-turn persistence.
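For readers trying the sample out, a minimal sketch of wiring a before-model callback onto an ADK agent is shown below. The agent name, model string, and `persist_next_injection` (from the earlier sketch) are placeholders, and the `before_model_callback` argument follows the ADK callback docs as I understand them, so double-check against the ADK version you use.

```python
from google.adk.agents import Agent

# Hypothetical wiring: the agent name and model id are placeholders, and
# persist_next_injection refers to the sketch shown earlier in this thread.
root_agent = Agent(
    name="injection_demo_agent",
    model="gemini-2.0-flash",
    instruction="You are a helpful assistant.",
    before_model_callback=persist_next_injection,
)
```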