
Commit 7c113ba

Merge pull request #80 from guardrails-ai/ai-proxy-updates
add support for streaming validation exceptions and exception handling; disable history support because of poor multi-node support
2 parents fc7a036 + 6217ffe commit 7c113ba

File tree

3 files changed: +80 −57 lines

guardrails_api/__init__.py
guardrails_api/api/guards.py
tests/api/test_guards.py

guardrails_api/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.1.0-alpha1"
+__version__ = "0.1.0-alpha2"

guardrails_api/api/guards.py

Lines changed: 77 additions & 50 deletions
@@ -42,6 +42,8 @@

 router = APIRouter()

+def guard_history_is_enabled():
+    return os.environ.get("GUARD_HISTORY_ENABLED", "true").lower() == "true"

 @router.get("/guards")
 @handle_error
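
The new guard_history_is_enabled helper gates all history persistence behind an environment variable. A minimal sketch of its behavior follows; the variable name comes from this diff, while the surrounding script is purely illustrative:

```python
# Sketch of the toggle added above: GUARD_HISTORY_ENABLED defaults to "true",
# and any other (case-insensitive) value disables history persistence.
import os


def guard_history_is_enabled() -> bool:
    return os.environ.get("GUARD_HISTORY_ENABLED", "true").lower() == "true"


os.environ["GUARD_HISTORY_ENABLED"] = "false"  # e.g. set for multi-node deployments
assert guard_history_is_enabled() is False

del os.environ["GUARD_HISTORY_ENABLED"]        # unset -> history stays enabled
assert guard_history_is_enabled() is True
```
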
@@ -125,7 +127,12 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
         )

     if not stream:
-        validation_outcome: ValidationOutcome = await guard(num_reasks=0, **payload)
+        execution = guard(num_reasks=0, **payload)
+        if inspect.iscoroutine(execution):
+            validation_outcome: ValidationOutcome = await execution
+        else:
+            validation_outcome: ValidationOutcome = execution
+
         llm_response = guard.history.last.iterations.last.outputs.llm_response_info
         result = outcome_to_chat_completion(
             validation_outcome=validation_outcome,
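
This hunk lets the endpoint accept both synchronous and asynchronous Guard implementations: the call result is awaited only when it is actually a coroutine. A self-contained sketch of the same idiom, where sync_guard and async_guard are stand-ins rather than Guardrails APIs:

```python
# Minimal sketch of the "await only if it's a coroutine" idiom used above.
# sync_guard/async_guard are illustrative stand-ins for a Guard's __call__.
import asyncio
import inspect


def sync_guard(num_reasks: int = 0, **payload):
    return {"validation_passed": True, "payload": payload}


async def async_guard(num_reasks: int = 0, **payload):
    return {"validation_passed": True, "payload": payload}


async def run(guard, **payload):
    execution = guard(num_reasks=0, **payload)
    if inspect.iscoroutine(execution):
        return await execution  # async guard: execution is a coroutine
    return execution            # sync guard: execution is already the outcome


print(asyncio.run(run(sync_guard, model="gpt-4o")))
print(asyncio.run(run(async_guard, model="gpt-4o")))
```
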
@@ -136,13 +143,17 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     else:

         async def openai_streamer():
-            guard_stream = await guard(num_reasks=0, **payload)
-            async for result in guard_stream:
-                chunk = json.dumps(
-                    outcome_to_stream_response(validation_outcome=result)
-                )
-                yield f"data: {chunk}\n\n"
-            yield "\n"
+            try:
+                guard_stream = await guard(num_reasks=0, **payload)
+                async for result in guard_stream:
+                    chunk = json.dumps(
+                        outcome_to_stream_response(validation_outcome=result)
+                    )
+                    yield f"data: {chunk}\n\n"
+                yield "\n"
+            except Exception as e:
+                yield f"data: {json.dumps({'error': {'message':str(e)}})}\n\n"
+                yield "\n"

         return StreamingResponse(openai_streamer(), media_type="text/event-stream")
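
With the new try/except, an exception raised while streaming (including a streaming validation error) is emitted as a final `data: {"error": ...}` event instead of abruptly closing the SSE response. A runnable sketch of that shape, with a stand-in stream in place of the guard and no FastAPI or Guardrails imports:

```python
# Sketch: turn an exception raised mid-stream into a terminal SSE error event,
# mirroring the except branch added above. failing_stream is a stand-in.
import asyncio
import json


async def failing_stream():
    yield {"chunk": 1}
    raise RuntimeError("validation failed mid-stream")


async def sse_events():
    try:
        async for result in failing_stream():
            yield f"data: {json.dumps(result)}\n\n"
        yield "\n"
    except Exception as e:
        yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
        yield "\n"


async def main():
    async for event in sse_events():
        print(repr(event))


asyncio.run(main())
```
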

@@ -196,58 +207,75 @@ async def validate(guard_name: str, request: Request):
             raise HTTPException(
                 status_code=400, detail="Streaming is not supported for parse calls!"
             )
-        result: ValidationOutcome = guard.parse(
+        execution = guard.parse(
             llm_output=llm_output,
             num_reasks=num_reasks,
             prompt_params=prompt_params,
             llm_api=llm_api,
             **payload,
         )
+        if inspect.iscoroutine(execution):
+            result: ValidationOutcome = await execution
+        else:
+            result: ValidationOutcome = execution
     else:
         if stream:
-
             async def guard_streamer():
-                guard_stream = guard(
-                    llm_api=llm_api,
-                    prompt_params=prompt_params,
-                    num_reasks=num_reasks,
-                    stream=stream,
-                    *args,
-                    **payload,
-                )
-                for result in guard_stream:
-                    validation_output = ValidationOutcome.from_guard_history(
-                        guard.history.last
+                call = guard(
+                    llm_api=llm_api,
+                    prompt_params=prompt_params,
+                    num_reasks=num_reasks,
+                    stream=stream,
+                    *args,
+                    **payload,
                 )
-                    yield validation_output, result
+                is_async = inspect.iscoroutine(call)
+                if is_async:
+                    guard_stream = await call
+                    async for result in guard_stream:
+                        validation_output = ValidationOutcome.from_guard_history(
+                            guard.history.last
+                        )
+                        yield validation_output, result
+                else:
+                    guard_stream = call
+                    for result in guard_stream:
+                        validation_output = ValidationOutcome.from_guard_history(
+                            guard.history.last
+                        )
+                        yield validation_output, result

             async def validate_streamer(guard_iter):
-                async for validation_output, result in guard_iter:
-                    fragment_dict = result.to_dict()
-                    fragment_dict["error_spans"] = [
+                try:
+                    async for validation_output, result in guard_iter:
+                        fragment_dict = result.to_dict()
+                        fragment_dict["error_spans"] = [
+                            json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
+                            for x in guard.error_spans_in_output()
+                        ]
+                        yield json.dumps(fragment_dict) + "\n"
+
+                    call = guard.history.last
+                    final_validation_output = ValidationOutcome(
+                        callId=call.id,
+                        validation_passed=result.validation_passed,
+                        validated_output=result.validated_output,
+                        history=guard.history,
+                        raw_llm_output=result.raw_llm_output,
+                    )
+                    final_output_dict = final_validation_output.to_dict()
+                    final_output_dict["error_spans"] = [
                         json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
                         for x in guard.error_spans_in_output()
                     ]
-                    yield json.dumps(fragment_dict) + "\n"
-
-                call = guard.history.last
-                final_validation_output = ValidationOutcome(
-                    callId=call.id,
-                    validation_passed=result.validation_passed,
-                    validated_output=result.validated_output,
-                    history=guard.history,
-                    raw_llm_output=result.raw_llm_output,
-                )
-                final_output_dict = final_validation_output.to_dict()
-                final_output_dict["error_spans"] = [
-                    json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
-                    for x in guard.error_spans_in_output()
-                ]
-                yield json.dumps(final_output_dict) + "\n"
-
-                serialized_history = [call.to_dict() for call in guard.history]
-                cache_key = f"{guard.name}-{final_validation_output.call_id}"
-                await cache_client.set(cache_key, serialized_history, 300)
+                    yield json.dumps(final_output_dict) + "\n"
+                except Exception as e:
+                    yield json.dumps({"error": {"message": str(e)}}) + "\n"
+
+                if guard_history_is_enabled():
+                    serialized_history = [call.to_dict() for call in guard.history]
+                    cache_key = f"{guard.name}-{final_validation_output.call_id}"
+                    await cache_client.set(cache_key, serialized_history, 300)

             return StreamingResponse(
                 validate_streamer(guard_streamer()), media_type="application/json"
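
guard_streamer now branches on whether guard() returned a coroutine, consuming either an async stream or a plain generator through one async generator. A self-contained sketch of that bridge; sync_call and async_call below are illustrative stand-ins for guard(..., stream=True):

```python
# Sketch of the sync-or-async streaming bridge used in guard_streamer above.
import asyncio
import inspect


def sync_call(**payload):
    # Sync guard: the call returns a plain generator of results.
    return ({"chunk": i} for i in range(2))


async def _async_results(n):
    for i in range(n):
        yield {"chunk": i}


async def async_call(**payload):
    # Async guard: the call itself is a coroutine that resolves to an async iterator.
    return _async_results(2)


async def streamer(call):
    if inspect.iscoroutine(call):
        stream = await call
        async for result in stream:
            yield result
    else:
        for result in call:
            yield result


async def main():
    async for r in streamer(sync_call(stream=True)):
        print("sync ->", r)
    async for r in streamer(async_call(stream=True)):
        print("async ->", r)


asyncio.run(main())
```
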
@@ -260,15 +288,14 @@ async def validate_streamer(guard_iter):
                 *args,
                 **payload,
             )
-
             if inspect.iscoroutine(execution):
                 result: ValidationOutcome = await execution
             else:
                 result: ValidationOutcome = execution
-
-    serialized_history = [call.to_dict() for call in guard.history]
-    cache_key = f"{guard.name}-{result.call_id}"
-    await cache_client.set(cache_key, serialized_history, 300)
+    if guard_history_is_enabled():
+        serialized_history = [call.to_dict() for call in guard.history]
+        cache_key = f"{guard.name}-{result.call_id}"
+        await cache_client.set(cache_key, serialized_history, 300)
     return result.to_dict()


tests/api/test_guards.py

Lines changed: 2 additions & 6 deletions
@@ -14,9 +14,6 @@
 from tests.mocks.mock_guard_client import MockGuardStruct
 from guardrails_api.api.guards import router as guards_router

-
-import asyncio
-
 # TODO: Should we mock this somehow?
 # Right now it's just empty, but it technically does a file read
 register_config()
@@ -347,9 +344,8 @@ def test_openai_v1_chat_completions__call(mocker):
     )

     mock___call__ = mocker.patch.object(MockGuardStruct, "__call__")
-    future = asyncio.Future()
-    future.set_result(mock_outcome)
-    mock___call__.return_value = future
+
+    mock___call__.return_value = mock_outcome

     mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict")
     mock_from_dict.return_value = mock_guard
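
The test no longer wraps the mocked outcome in an asyncio.Future because the endpoint now checks the return value before awaiting it. A small sketch of why a plain return value is enough; the names below are illustrative, not the project's test fixtures:

```python
# Sketch: the endpoint inspects what the mocked __call__ returned, so a plain
# (non-awaitable) value works and the Future wrapper is no longer needed.
import asyncio
import inspect


def mocked_guard_call(**payload):
    return "mock outcome"  # plain value, no Future or coroutine needed


async def endpoint_under_test():
    execution = mocked_guard_call(num_reasks=0)
    return await execution if inspect.iscoroutine(execution) else execution


assert asyncio.run(endpoint_under_test()) == "mock outcome"
```
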
