
Commit 8498640

Merge pull request #84 from guardrails-ai/forward-compatibility
Forward compatibility
2 parents 420fa37 + 4e36257 · commit 8498640

13 files changed: +98 -42 lines changed

guardrails_api/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.0.4"
+__version__ = "0.0.5"

guardrails_api/app.py

Lines changed: 14 additions & 7 deletions
@@ -9,7 +9,9 @@
 from opentelemetry.instrumentation.flask import FlaskInstrumentor
 from guardrails_api.clients.postgres_client import postgres_is_enabled
 from guardrails_api.otel import otel_is_disabled, initialize
-from guardrails_api.utils.trace_server_start_if_enabled import trace_server_start_if_enabled
+from guardrails_api.utils.trace_server_start_if_enabled import (
+    trace_server_start_if_enabled,
+)
 from guardrails_api.clients.cache_client import CacheClient
 from rich.console import Console
 from rich.rule import Rule
@@ -84,7 +86,7 @@ def create_app(

     @app.before_request
     def basic_cors():
-        if request.method.lower() == 'options':
+        if request.method.lower() == "options":
             return Response()

     app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1)
@@ -112,20 +114,25 @@ def basic_cors():
     app.register_blueprint(root_bp)
     app.register_blueprint(guards_bp)

+    console.print(f"\n:rocket: Guardrails API is available at {self_endpoint}")
     console.print(
-        f"\n:rocket: Guardrails API is available at {self_endpoint}"
+        f":book: Visit {self_endpoint}/docs to see available API endpoints.\n"
     )
-    console.print(f":book: Visit {self_endpoint}/docs to see available API endpoints.\n")

     console.print(":green_circle: Active guards and OpenAI compatible endpoints:")

     with app.app_context():
         from guardrails_api.blueprints.guards import guard_client
+
         for g in guard_client.get_guards():
             g = g.to_dict()
-            console.print(f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1")
+            console.print(
+                f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1"
+            )

     console.print("")
-    console.print(Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white"))
+    console.print(
+        Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")
+    )

-    return app
+    return app
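
Note: the startup banner above advertises one OpenAI-compatible endpoint per guard at {self_endpoint}/guards/<name>/openai/v1. A minimal client-side sketch of what that enables follows; the host, port, guard name, and the /chat/completions suffix are assumptions for illustration, not taken from this diff (the x-openai-api-key header mirrors the tests further below):

import requests

# Hypothetical request against a locally running server; URL parts are
# illustrative. The guard endpoint accepts an OpenAI-style messages payload.
resp = requests.post(
    "http://localhost:8000/guards/my-guard/openai/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello world!"}]},
    headers={"x-openai-api-key": "sk-placeholder"},
)
print(resp.json())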

guardrails_api/blueprints/guards.py

Lines changed: 16 additions & 7 deletions
@@ -18,7 +18,10 @@
 from guardrails_api.clients.postgres_client import postgres_is_enabled
 from guardrails_api.utils.handle_error import handle_error
 from guardrails_api.utils.get_llm_callable import get_llm_callable
-from guardrails_api.utils.openai import outcome_to_chat_completion, outcome_to_stream_response
+from guardrails_api.utils.openai import (
+    outcome_to_chat_completion,
+    outcome_to_stream_response,
+)

 guards_bp = Blueprint("guards", __name__, url_prefix="/guards")

@@ -272,7 +275,6 @@ def validate(guard_name: str):
     # ) as validate_span:
     #     guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)

-
     # validate_span.set_attribute("guardName", decoded_guard_name)
     if llm_api is not None:
         llm_api = get_llm_callable(llm_api)
@@ -295,7 +297,7 @@ def validate(guard_name: str):
         else:
             guard: Guard = Guard.from_dict(guard_struct.to_dict())
     elif is_async:
-        guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict())
+        guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict())

     if llm_api is None and num_reasks and num_reasks > 1:
         raise HttpError(
@@ -322,6 +324,7 @@ def validate(guard_name: str):
         )
     else:
         if stream:
+
             def guard_streamer():
                 guard_stream = guard(
                     llm_api=llm_api,
@@ -452,24 +455,30 @@ async def async_validate_streamer(guard_iter):
             cache_key = f"{guard.name}-{final_validation_output.call_id}"
             cache_client.set(cache_key, serialized_history, 300)
             yield f"{final_output_json}\n"
+
         # apropos of https://stackoverflow.com/questions/73949570/using-stream-with-context-as-async
         def iter_over_async(ait, loop):
            ait = ait.__aiter__()
+
             async def get_next():
-                try:
+                try:
                     obj = await ait.__anext__()
                     return False, obj
-                except StopAsyncIteration:
+                except StopAsyncIteration:
                     return True, None
+
             while True:
                 done, obj = loop.run_until_complete(get_next())
-                if done:
+                if done:
                     break
                 yield obj
+
         if is_async:
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
-            iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop)
+            iter = iter_over_async(
+                async_validate_streamer(async_guard_streamer()), loop
+            )
         else:
             iter = validate_streamer(guard_streamer())
         return Response(
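
Note: the iter_over_async helper in the last hunk bridges an async generator into a synchronous one, so Flask's streaming Response can consume the async validation stream. A self-contained sketch of the same pattern, where demo_chunks is a stand-in for the guard's async stream and not part of this codebase:

import asyncio


def iter_over_async(ait, loop):
    # Drain an async iterator from synchronous code by running one
    # __anext__ step at a time on the given event loop.
    ait = ait.__aiter__()

    async def get_next():
        try:
            obj = await ait.__anext__()
            return False, obj
        except StopAsyncIteration:
            return True, None

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj


async def demo_chunks():
    # Stand-in for the async validation stream.
    for i in range(3):
        await asyncio.sleep(0)
        yield f"chunk-{i}"


loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
for chunk in iter_over_async(demo_chunks(), loop):
    print(chunk)  # chunk-0, chunk-1, chunk-2
loop.close()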

guardrails_api/cli/start.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 from guardrails_api.app import create_app
 from guardrails_api.utils.configuration import valid_configuration

+
 @cli.command("start")
 def start(
     env: Optional[str] = typer.Option(

guardrails_api/utils/configuration.py

Lines changed: 19 additions & 8 deletions
@@ -2,21 +2,32 @@
 from typing import Optional
 import os

-def valid_configuration(config: Optional[str]=""):
+
+def valid_configuration(config: Optional[str] = ""):
     default_config_file = os.path.join(os.getcwd(), "./config.py")

     default_config_file_path = os.path.abspath(default_config_file)
-    # If config.py is not present and
+    # If config.py is not present and
     # if a config filepath is not passed and
-    # if postgres is not there (i.e. we’re using in-mem db)
+    # if postgres is not there (i.e. we’re using in-mem db)
     # then raise ConfigurationError
     has_default_config_file = os.path.isfile(default_config_file_path)

-    has_config_file = (config != "" and config is not None) and os.path.isfile(os.path.abspath(config))
-    if not has_default_config_file and not has_config_file and not postgres_is_enabled():
-        raise ConfigurationError("Can not start. Configuration not provided and default"
-            " configuration not found and postgres is not enabled.")
+    has_config_file = (config != "" and config is not None) and os.path.isfile(
+        os.path.abspath(config)
+    )
+
+    if (
+        not has_default_config_file
+        and not has_config_file
+        and not postgres_is_enabled()
+    ):
+        raise ConfigurationError(
+            "Can not start. Configuration not provided and default"
+            " configuration not found and postgres is not enabled."
+        )
     return True

+
 class ConfigurationError(Exception):
-    pass
+    pass
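
Note: in effect, valid_configuration accepts any one of three signals — a config.py in the working directory, an explicit config path, or an enabled postgres backend — and raises ConfigurationError otherwise. A brief usage sketch (the file path is illustrative):

from guardrails_api.utils.configuration import (
    ConfigurationError,
    valid_configuration,
)

try:
    # Passes if ./config.py exists, the given path exists, or postgres is enabled.
    valid_configuration("my_config.py")
except ConfigurationError as e:
    print(f"Refusing to start: {e}")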

guardrails_api/utils/handle_error.py

Lines changed: 21 additions & 6 deletions
@@ -13,20 +13,35 @@ def decorator(*args, **kwargs):
             return fn(*args, **kwargs)
         except ValidationError as validation_error:
             logger.error(validation_error)
-            traceback.print_exception(type(validation_error), validation_error, validation_error.__traceback__)
-            return str(validation_error), 400
+            traceback.print_exception(
+                type(validation_error), validation_error, validation_error.__traceback__
+            )
+            resp_body = {"status_code": 400, "detail": str(validation_error)}
+            return resp_body, 400
         except HttpError as http_error:
             logger.error(http_error)
-            traceback.print_exception(type(http_error), http_error, http_error.__traceback__)
-            return http_error.to_dict(), http_error.status
+            traceback.print_exception(
+                type(http_error), http_error, http_error.__traceback__
+            )
+            resp_body = http_error.to_dict()
+            resp_body["status_code"] = http_error.status
+            resp_body["detail"] = http_error.message
+            return resp_body, http_error.status
         except HTTPException as http_exception:
             logger.error(http_exception)
             traceback.print_exception(http_exception)
             http_error = HttpError(http_exception.code, http_exception.description)
-            return http_error.to_dict(), http_error.status
+            resp_body = http_error.to_dict()
+            resp_body["status_code"] = http_error.status
+            resp_body["detail"] = http_error.message
+
+            return resp_body, http_error.status
         except Exception as e:
             logger.error(e)
             traceback.print_exception(e)
-            return HttpError(500, "Internal Server Error").to_dict(), 500
+            resp_body = HttpError(500, "Internal Server Error").to_dict()
+            resp_body["status_code"] = 500
+            resp_body["detail"] = "Internal Server Error"
+            return resp_body, 500

     return decorator
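
Note: these handler changes are the substance of the forward-compatibility work, and the reason for the test updates below: every error path now returns a JSON body carrying status_code and detail alongside the HTTP status, instead of a bare string or an unadorned HttpError dict. A client-side sketch of the new shape (the endpoint URL and payload key are illustrative assumptions, not taken from this diff):

import requests

# Hypothetical failed validation against a local server.
resp = requests.post(
    "http://localhost:8000/guards/my-guard/validate",
    json={"llmOutput": "text that fails validation"},
)
if resp.status_code >= 400:
    err = resp.json()
    # Both fields are now populated by the updated handle_error decorator.
    print(err["status_code"], err["detail"])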

guardrails_api/utils/has_internet_connection.py

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@ def has_internet_connection() -> bool:
         res.raise_for_status()
         return True
     except requests.ConnectionError:
-        return False
+        return False

guardrails_api/utils/openai.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from guardrails.classes import ValidationOutcome

+
 def outcome_to_stream_response(validation_outcome: ValidationOutcome):
     stream_chunk_template = {
         "choices": [

guardrails_api/utils/trace_server_start_if_enabled.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@ def trace_server_start_if_enabled():
     config = Credentials.from_rc_file()
     if config.enable_metrics is True and has_internet_connection():
         from guardrails.utils.hub_telemetry_utils import HubTelemetry
+
         HubTelemetry().create_new_span(
             "guardrails-api/start",
             [
@@ -21,4 +22,4 @@ def trace_server_start_if_enabled():
             ],
             True,
             False,
-        )
+        )

tests/blueprints/test_guards.py

Lines changed: 15 additions & 8 deletions
@@ -549,6 +549,7 @@ def test_validate__call(mocker):

     del os.environ["PGHOST"]

+
 def test_validate__call_throws_validation_error(mocker):
     os.environ["PGHOST"] = "localhost"

@@ -610,19 +611,24 @@ def test_validate__call_throws_validation_error(mocker):
         prompt="Hello world!",
     )

-    assert response == ('Test guard validation error', 400)
+    assert response == (
+        {"status_code": 400, "detail": "Test guard validation error"},
+        400,
+    )

     del os.environ["PGHOST"]

+
 def test_openai_v1_chat_completions__raises_404(mocker):
     from guardrails_api.blueprints.guards import openai_v1_chat_completions
+
     os.environ["PGHOST"] = "localhost"
     mock_guard = None

     mock_request = MockRequest(
         "POST",
         json={
-            "messages": [{"role":"user", "content":"Hello world!"}],
+            "messages": [{"role": "user", "content": "Hello world!"}],
         },
         headers={"x-openai-api-key": "mock-key"},
     )
@@ -637,15 +643,16 @@ def test_openai_v1_chat_completions__raises_404(mocker):

     response = openai_v1_chat_completions("My%20Guard's%20Name")
     assert response[1] == 404
-    assert response[0]["message"] == 'NotFound'
-
+    assert response[0]["message"] == "NotFound"

     mock_get_guard.assert_called_once_with("My Guard's Name")

     del os.environ["PGHOST"]

+
 def test_openai_v1_chat_completions__call(mocker):
     from guardrails_api.blueprints.guards import openai_v1_chat_completions
+
     os.environ["PGHOST"] = "localhost"
     mock_guard = MockGuardStruct()
     mock_outcome = ValidationOutcome(
@@ -664,7 +671,7 @@ def test_openai_v1_chat_completions__call(mocker):
     mock_request = MockRequest(
         "POST",
         json={
-            "messages": [{"role":"user", "content":"Hello world!"}],
+            "messages": [{"role": "user", "content": "Hello world!"}],
         },
         headers={"x-openai-api-key": "mock-key"},
     )
@@ -687,7 +694,7 @@ def test_openai_v1_chat_completions__call(mocker):
     )
     mock_status.return_value = "fail"
     mock_call = Call()
-    mock_call.iterations= Stack(Iteration('some-id', 1))
+    mock_call.iterations = Stack(Iteration("some-id", 1))
     mock_guard.history = Stack(mock_call)

     response = openai_v1_chat_completions("My%20Guard's%20Name")
@@ -698,7 +705,7 @@ def test_openai_v1_chat_completions__call(mocker):

     mock___call__.assert_called_once_with(
         num_reasks=0,
-        messages=[{"role":"user", "content":"Hello world!"}],
+        messages=[{"role": "user", "content": "Hello world!"}],
     )

     assert response == {
@@ -716,4 +723,4 @@ def test_openai_v1_chat_completions__call(mocker):
         },
     }

-    del os.environ["PGHOST"]
+    del os.environ["PGHOST"]
