42
42
43
43
# Router on which this module's endpoint handlers are registered
# (e.g. via the `@router.get("/guards")` decorator).
router = APIRouter ()
44
44
45
def guard_history_is_enabled():
    """Report whether the guard-history feature flag is switched on.

    The flag is read from the ``GUARD_HISTORY_ENABLED`` environment variable
    on every call. An unset variable counts as ``"true"``, and the comparison
    ignores letter case.

    Returns:
        bool: ``True`` only when the lowercased value equals ``"true"``.
    """
    flag = os.getenv("GUARD_HISTORY_ENABLED", "true")
    return flag.lower() == "true"
45
47
46
48
@router .get ("/guards" )
47
49
@handle_error
@@ -125,7 +127,12 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
125
127
)
126
128
127
129
if not stream :
128
- validation_outcome : ValidationOutcome = await guard (num_reasks = 0 , ** payload )
130
+ execution = guard (num_reasks = 0 , ** payload )
131
+ if inspect .iscoroutine (execution ):
132
+ validation_outcome : ValidationOutcome = await execution
133
+ else :
134
+ validation_outcome : ValidationOutcome = execution
135
+
129
136
llm_response = guard .history .last .iterations .last .outputs .llm_response_info
130
137
result = outcome_to_chat_completion (
131
138
validation_outcome = validation_outcome ,
@@ -136,13 +143,17 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
136
143
else :
137
144
138
145
async def openai_streamer():
    """Stream guard results as server-sent events.

    Calls the enclosing scope's ``guard`` with ``num_reasks=0`` and the
    request ``payload``, converts every streamed result with
    ``outcome_to_stream_response`` and emits it as a ``data:`` frame.
    A failure anywhere in the stream is reported in-band as an
    ``{"error": {"message": ...}}`` frame instead of being raised.

    Yields:
        str: ``data: <json>`` frames separated by blank lines, followed by
        a closing newline.
    """

    def as_event(body):
        # One SSE frame: a ``data:`` line terminated by a blank line.
        return f"data: {json.dumps(body)}\n\n"

    try:
        execution = guard(num_reasks=0, **payload)
        # Mirror the non-streaming branch: the call may hand back either a
        # coroutine (awaited to get an async stream) or a plain iterable.
        if inspect.iscoroutine(execution):
            guard_stream = await execution
            async for result in guard_stream:
                yield as_event(
                    outcome_to_stream_response(validation_outcome=result)
                )
        else:
            for result in execution:
                yield as_event(
                    outcome_to_stream_response(validation_outcome=result)
                )
        yield "\n"
    except Exception as e:
        # The response is produced incrementally, so the error is surfaced
        # as a final frame in the same stream.
        yield as_event({"error": {"message": str(e)}})
        yield "\n"
146
157
147
158
return StreamingResponse (openai_streamer (), media_type = "text/event-stream" )
148
159
@@ -196,58 +207,75 @@ async def validate(guard_name: str, request: Request):
196
207
raise HTTPException (
197
208
status_code = 400 , detail = "Streaming is not supported for parse calls!"
198
209
)
199
- result : ValidationOutcome = guard .parse (
210
+ execution = guard .parse (
200
211
llm_output = llm_output ,
201
212
num_reasks = num_reasks ,
202
213
prompt_params = prompt_params ,
203
214
llm_api = llm_api ,
204
215
** payload ,
205
216
)
217
+ if inspect .iscoroutine (execution ):
218
+ result : ValidationOutcome = await execution
219
+ else :
220
+ result : ValidationOutcome = execution
206
221
else :
207
222
if stream :
208
-
209
223
async def guard_streamer():
    """Run the guard in streaming mode and pair each fragment with an outcome.

    The enclosing scope's ``guard`` is invoked with the request arguments.
    When the call returns a coroutine it is awaited and the resulting stream
    is consumed with ``async for``; otherwise the returned iterable is
    consumed with a regular ``for`` loop.

    Yields:
        tuple: ``(validation_output, fragment)`` where ``validation_output``
        is built from the guard's most recent history entry at the time the
        fragment arrives.
    """
    pending = guard(
        *args,
        llm_api=llm_api,
        prompt_params=prompt_params,
        num_reasks=num_reasks,
        stream=stream,
        **payload,
    )

    if inspect.iscoroutine(pending):
        async for fragment in await pending:
            latest = ValidationOutcome.from_guard_history(guard.history.last)
            yield latest, fragment
        return

    # NOTE(review): a synchronous stream is iterated inline here; presumably
    # this blocks the event loop while waiting on each chunk -- confirm.
    for fragment in pending:
        latest = ValidationOutcome.from_guard_history(guard.history.last)
        yield latest, fragment
223
247
224
248
async def validate_streamer(guard_iter):
    """Serialize streamed validation results as newline-delimited JSON.

    Every ``(validation_output, result)`` pair from ``guard_iter`` becomes one
    JSON line holding the fragment plus the guard's current error spans. Once
    the stream is exhausted, a summary line is built from the last fragment
    and the guard's history. Any exception is reported in-band as an
    ``{"error": {"message": ...}}`` line instead of being raised.

    When guard history is enabled and the summary was produced, the
    serialized history is stored in the cache under
    ``"<guard name>-<call id>"`` with the value ``300`` passed as the third
    argument to ``cache_client.set``.

    Args:
        guard_iter: Async iterable yielding ``(validation_output, result)``
            tuples.

    Yields:
        str: One JSON document per line.
    """

    def serialized_error_spans():
        # Each span is itself JSON-encoded, so the enclosing document carries
        # a list of JSON strings (existing wire format, kept as is).
        return [
            json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
            for x in guard.error_spans_in_output()
        ]

    # Stays None when the stream fails before the summary is built, so the
    # cache step below can tell that there is no call id to key on.
    final_validation_output = None

    try:
        async for _validation_output, result in guard_iter:
            fragment_dict = result.to_dict()
            fragment_dict["error_spans"] = serialized_error_spans()
            yield json.dumps(fragment_dict) + "\n"

        call = guard.history.last
        final_validation_output = ValidationOutcome(
            callId=call.id,
            validation_passed=result.validation_passed,
            validated_output=result.validated_output,
            history=guard.history,
            raw_llm_output=result.raw_llm_output,
        )
        final_output_dict = final_validation_output.to_dict()
        final_output_dict["error_spans"] = serialized_error_spans()
        yield json.dumps(final_output_dict) + "\n"
    except Exception as e:
        yield json.dumps({"error": {"message": str(e)}}) + "\n"

    # Without a summary there is no call id; skipping avoids dereferencing an
    # unassigned local after an error line has already been sent.
    if final_validation_output is not None and guard_history_is_enabled():
        serialized_history = [call.to_dict() for call in guard.history]
        cache_key = f"{guard.name}-{final_validation_output.call_id}"
        await cache_client.set(cache_key, serialized_history, 300)
251
279
252
280
return StreamingResponse (
253
281
validate_streamer (guard_streamer ()), media_type = "application/json"
@@ -260,15 +288,14 @@ async def validate_streamer(guard_iter):
260
288
* args ,
261
289
** payload ,
262
290
)
263
-
264
291
if inspect .iscoroutine (execution ):
265
292
result : ValidationOutcome = await execution
266
293
else :
267
294
result : ValidationOutcome = execution
268
-
269
- serialized_history = [call .to_dict () for call in guard .history ]
270
- cache_key = f"{ guard .name } -{ result .call_id } "
271
- await cache_client .set (cache_key , serialized_history , 300 )
295
+ if guard_history_is_enabled ():
296
+ serialized_history = [call .to_dict () for call in guard .history ]
297
+ cache_key = f"{ guard .name } -{ result .call_id } "
298
+ await cache_client .set (cache_key , serialized_history , 300 )
272
299
return result .to_dict ()
273
300
274
301
0 commit comments