This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit 87274ae

feat: add caching if prompt request fails (#148)

1 parent 5f4c92a

File tree: 11 files changed, +261 -52 lines

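In effect, `get_prompt` now degrades gracefully: every successfully fetched prompt is written to a shared in-process cache, and if a later request fails, the cached copy is returned instead of raising. A minimal usage sketch (the key and URL are hypothetical placeholders):

from literalai.api import LiteralAPI

# Hypothetical credentials; substitute your own.
api = LiteralAPI("lsk_...", "https://cloud.getliteral.ai")

# First call hits the GraphQL API with the default 10 s timeout and
# caches the prompt under its id, name, and "name-version" keys.
prompt = api.get_prompt(name="rag-prompt")

# With a cached copy available, the timeout drops to 1 s; if the
# request then fails, the cached prompt is returned instead of raising.
prompt = api.get_prompt(name="rag-prompt")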

literalai/api/asynchronous.py

Lines changed: 31 additions & 20 deletions
@@ -3,7 +3,6 @@
 
 from typing_extensions import deprecated
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -106,9 +105,6 @@
 from literalai.observability.thread import Thread
 from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
 
-if TYPE_CHECKING:
-    from typing import Tuple  # noqa: F401
-
 import httpx
 
 from literalai.my_types import PaginatedResponse, User
@@ -145,7 +141,7 @@ class AsyncLiteralAPI(BaseLiteralAPI):
     R = TypeVar("R")
 
     async def make_gql_call(
-        self, description: str, query: str, variables: Dict[str, Any]
+        self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10
     ) -> Dict:
         def raise_error(error):
             logger.error(f"Failed to {description}: {error}")
@@ -158,7 +154,7 @@ def raise_error(error):
                 self.graphql_endpoint,
                 json={"query": query, "variables": variables},
                 headers=self.headers,
-                timeout=10,
+                timeout=timeout,
             )
 
             try:
@@ -179,13 +175,12 @@ def raise_error(error):
 
             if json.get("data"):
                 if isinstance(json["data"], dict):
-                    for _, value in json["data"].items():
+                    for value in json["data"].values():
                         if value and value.get("ok") is False:
                             raise_error(
                                 f"""Failed to {description}: {
                                     value.get('message')}"""
                             )
-
             return json
 
     async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
@@ -211,15 +206,15 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
                     f"""Failed to parse JSON response: {
                         e}, content: {response.content!r}"""
                 )
-
     async def gql_helper(
         self,
         query: str,
         description: str,
         variables: Dict,
         process_response: Callable[..., R],
+        timeout: Optional[int] = 10,
    ) -> R:
-        response = await self.make_gql_call(description, query, variables)
+        response = await self.make_gql_call(description, query, variables, timeout)
         return process_response(response)
 
     ##################################################################################
@@ -447,7 +442,7 @@ async def upload_file(
         # Prepare form data
         form_data = (
             {}
-        )  # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
+        )  # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
         for field_name, field_value in fields.items():
             form_data[field_name] = (None, field_value)
 
@@ -838,16 +833,32 @@ async def get_prompt(
         id: Optional[str] = None,
         name: Optional[str] = None,
         version: Optional[int] = None,
-    ) -> "Prompt":
+    ) -> Prompt:
+        if not (id or name):
+            raise ValueError("At least the `id` or the `name` must be provided.")
+
         sync_api = LiteralAPI(self.api_key, self.url)
-        if id:
-            return await self.gql_helper(*get_prompt_helper(sync_api, id=id))
-        elif name:
-            return await self.gql_helper(
-                *get_prompt_helper(sync_api, name=name, version=version)
-            )
-        else:
-            raise ValueError("Either the `id` or the `name` must be provided.")
+        get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
+            api=sync_api, id=id, name=name, version=version, cache=self.cache
+        )
+
+        try:
+            if id:
+                prompt = await self.gql_helper(
+                    get_prompt_query, description, variables, process_response, timeout
+                )
+            elif name:
+                prompt = await self.gql_helper(
+                    get_prompt_query, description, variables, process_response, timeout
+                )
+
+            return prompt
+
+        except Exception as e:
+            if cached_prompt:
+                logger.warning("Failed to get prompt from API, returning cached prompt")
+                return cached_prompt
+            raise e
 
     async def update_prompt_ab_testing(
         self, name: str, rollouts: List["PromptRollout"]

literalai/api/base.py

Lines changed: 5 additions & 2 deletions
@@ -13,6 +13,7 @@
 
 from literalai.my_types import Environment
 
+from literalai.cache.shared_cache import SharedCache
 from literalai.evaluation.dataset import DatasetType
 from literalai.evaluation.dataset_experiment import (
     DatasetExperimentItem,
@@ -95,6 +96,8 @@ def __init__(
         self.graphql_endpoint = self.url + "/api/graphql"
         self.rest_endpoint = self.url + "/api"
 
+        self.cache = SharedCache()
+
     @property
     def headers(self):
         from literalai.version import __version__
@@ -1011,9 +1014,9 @@ def get_prompt(
         """
         Gets a prompt either by:
         - `id`
-        - or `name` and (optional) `version`
+        - `name` and (optional) `version`
 
-        Either the `id` or the `name` must be provided.
+        At least the `id` or the `name` must be passed to the function.
         If both are provided, the `id` is used.
 
         Args:

literalai/api/helpers/prompt_helpers.py

Lines changed: 39 additions & 12 deletions
@@ -1,10 +1,13 @@
-from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict
+from typing import TYPE_CHECKING, Optional, TypedDict, Callable
 
 from literalai.observability.generation import GenerationMessage
 from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
 
+from literalai.cache.prompt_helpers import put_prompt
+
 if TYPE_CHECKING:
     from literalai.api import LiteralAPI
+    from literalai.cache.shared_cache import SharedCache
 
 from literalai.api.helpers import gql
 
@@ -36,9 +39,9 @@ def process_response(response):
 def create_prompt_helper(
     api: "LiteralAPI",
     lineage_id: str,
-    template_messages: List[GenerationMessage],
+    template_messages: list[GenerationMessage],
     settings: Optional[ProviderSettings] = None,
-    tools: Optional[List[Dict]] = None,
+    tools: Optional[list[dict]] = None,
 ):
     variables = {
         "lineageId": lineage_id,
@@ -56,28 +59,52 @@ def process_response(response):
     return gql.CREATE_PROMPT_VERSION, description, variables, process_response
 
 
+def get_prompt_cache_key(id: Optional[str], name: Optional[str], version: Optional[int]) -> str:
+    if id:
+        return id
+    elif name and version:
+        return f"{name}-{version}"
+    elif name:
+        return name
+    else:
+        raise ValueError("Either the `id` or the `name` must be provided.")
+
+
 def get_prompt_helper(
     api: "LiteralAPI",
     id: Optional[str] = None,
     name: Optional[str] = None,
     version: Optional[int] = 0,
-):
+    cache: Optional["SharedCache"] = None,
+) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]:
+    """Helper function for getting prompts with caching logic"""
+
+    cached_prompt = None
+    timeout = 10
+
+    if cache:
+        cached_prompt = cache.get(get_prompt_cache_key(id, name, version))
+        timeout = 1 if cached_prompt else timeout
+
     variables = {"id": id, "name": name, "version": version}
 
     def process_response(response):
-        prompt = response["data"]["promptVersion"]
-        return Prompt.from_dict(api, prompt) if prompt else None
+        prompt_version = response["data"]["promptVersion"]
+        prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None
+        if cache and prompt:
+            put_prompt(cache, prompt)
+        return prompt
 
     description = "get prompt"
 
-    return gql.GET_PROMPT_VERSION, description, variables, process_response
+    return gql.GET_PROMPT_VERSION, description, variables, process_response, timeout, cached_prompt
 
 
 def create_prompt_variant_helper(
     from_lineage_id: Optional[str] = None,
-    template_messages: List[GenerationMessage] = [],
+    template_messages: list[GenerationMessage] = [],
     settings: Optional[ProviderSettings] = None,
-    tools: Optional[List[Dict]] = None,
+    tools: Optional[list[dict]] = None,
 ):
     variables = {
         "fromLineageId": from_lineage_id,
@@ -105,7 +132,7 @@ def get_prompt_ab_testing_helper(
 ):
     variables = {"lineageName": name}
 
-    def process_response(response) -> List[PromptRollout]:
+    def process_response(response) -> list[PromptRollout]:
         response_data = response["data"]["promptLineageRollout"]
         return list(map(lambda x: x["node"], response_data["edges"]))
 
@@ -114,10 +141,10 @@ def process_response(response) -> List[PromptRollout]:
     return gql.GET_PROMPT_AB_TESTING, description, variables, process_response
 
 
-def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]):
+def update_prompt_ab_testing_helper(name: str, rollouts: list[PromptRollout]):
     variables = {"name": name, "rollouts": rollouts}
 
-    def process_response(response) -> Dict:
+    def process_response(response) -> dict:
         return response["data"]["updatePromptLineageRollout"]
 
     description = "update prompt A/B testing"
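`get_prompt_helper` only needs two operations from the cache: `get(key)` and `put(key, value)` (the latter via `put_prompt`). The `SharedCache` class itself is added in literalai/cache/shared_cache.py, one of the 11 changed files but not shown in this excerpt. A rough sketch of the interface the helpers rely on (the singleton wiring is an assumption from the class name, not confirmed by this diff):

from typing import Any, Dict, Optional


class SharedCache:
    """Sketch of the cache interface assumed above; not the real implementation."""

    _instance: Optional["SharedCache"] = None
    _store: Dict[str, Any]

    def __new__(cls) -> "SharedCache":
        # Assumed singleton: every client constructed in BaseLiteralAPI.__init__
        # would then share one store.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._store = {}
        return cls._instance

    def get(self, key: str) -> Optional[Any]:
        # Returns None on a cache miss.
        return self._store.get(key)

    def put(self, key: str, value: Any) -> None:
        self._store[key] = value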

literalai/api/synchronous.py

Lines changed: 28 additions & 17 deletions
@@ -3,7 +3,6 @@
 
 from typing_extensions import deprecated
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -105,9 +104,6 @@
 from literalai.observability.thread import Thread
 from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
 
-if TYPE_CHECKING:
-    from typing import Tuple  # noqa: F401
-
 import httpx
 
 from literalai.my_types import PaginatedResponse, User
@@ -144,8 +140,8 @@ class LiteralAPI(BaseLiteralAPI):
     R = TypeVar("R")
 
     def make_gql_call(
-        self, description: str, query: str, variables: Dict[str, Any]
-    ) -> Dict:
+        self, description: str, query: str, variables: dict[str, Any], timeout: Optional[int] = 10
+    ) -> dict:
         def raise_error(error):
             logger.error(f"Failed to {description}: {error}")
             raise Exception(error)
@@ -156,7 +152,7 @@ def raise_error(error):
                 self.graphql_endpoint,
                 json={"query": query, "variables": variables},
                 headers=self.headers,
-                timeout=10,
+                timeout=timeout,
             )
 
             try:
@@ -177,7 +173,7 @@ def raise_error(error):
 
             if json.get("data"):
                 if isinstance(json["data"], dict):
-                    for _, value in json["data"].items():
+                    for value in json["data"].values():
                         if value and value.get("ok") is False:
                             raise_error(
                                 f"""Failed to {description}: {
@@ -186,7 +182,6 @@ def raise_error(error):
 
             return json
 
-
     def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
         with httpx.Client(follow_redirects=True) as client:
             response = client.post(
@@ -217,8 +212,9 @@ def gql_helper(
         description: str,
         variables: Dict,
         process_response: Callable[..., R],
+        timeout: Optional[int] = None,
     ) -> R:
-        response = self.make_gql_call(description, query, variables)
+        response = self.make_gql_call(description, query, variables, timeout)
         return process_response(response)
 
     ##################################################################################
@@ -441,7 +437,7 @@ def upload_file(
         # Prepare form data
         form_data = (
             {}
-        )  # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
+        )  # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
         for field_name, field_value in fields.items():
             form_data[field_name] = (None, field_value)
 
@@ -805,12 +801,27 @@ def get_prompt(
         name: Optional[str] = None,
         version: Optional[int] = None,
     ) -> "Prompt":
-        if id:
-            return self.gql_helper(*get_prompt_helper(self, id=id))
-        elif name:
-            return self.gql_helper(*get_prompt_helper(self, name=name, version=version))
-        else:
-            raise ValueError("Either the `id` or the `name` must be provided.")
+        if not (id or name):
+            raise ValueError("At least the `id` or the `name` must be provided.")
+
+        get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
+            api=self, id=id, name=name, version=version, cache=self.cache
+        )
+
+        try:
+            if id:
+                prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)
+            elif name:
+                prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)
+
+            return prompt
+
+        except Exception as e:
+            if cached_prompt:
+                logger.warning("Failed to get prompt from API, returning cached prompt")
+                return cached_prompt
+
+            raise e
 
     def create_prompt_variant(
         self,

literalai/cache/__init__.py

Whitespace-only changes.

literalai/cache/prompt_helpers.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+from literalai.prompt_engineering.prompt import Prompt
+from literalai.cache.shared_cache import SharedCache
+
+
+def put_prompt(cache: SharedCache, prompt: Prompt):
+    cache.put(prompt.id, prompt)
+    cache.put(prompt.name, prompt)
+    cache.put(f"{prompt.name}-{prompt.version}", prompt)
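`put_prompt` registers the prompt under all three keys that `get_prompt_cache_key` can produce, so a later lookup by `id`, by `name`, or by `name` and `version` resolves to the same entry. A quick sketch, assuming `prompt` is a Prompt fetched earlier:

from literalai.cache.prompt_helpers import put_prompt
from literalai.cache.shared_cache import SharedCache

cache = SharedCache()
put_prompt(cache, prompt)

assert cache.get(prompt.id) is prompt
assert cache.get(prompt.name) is prompt
assert cache.get(f"{prompt.name}-{prompt.version}") is prompt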
