This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit 4a0b063

Merge branch 'main' of github.com:Chainlit/literalai-python into willy/eng-1754-fix-mistralai-instrumentation-for-100
2 parents: 4c4ce14 + 5011c6d

File tree (5 files changed: +149, -80 lines)

    literalai/api/__init__.py
    literalai/api/gql.py
    literalai/api/prompt_helpers.py
    literalai/prompt_engineering/prompt.py
    tests/e2e/test_e2e.py

literalai/api/__init__.py

Lines changed: 59 additions & 43 deletions
@@ -16,25 +16,6 @@
 
 from typing_extensions import deprecated
 
-from literalai.context import active_steps_var, active_thread_var
-from literalai.evaluation.dataset import Dataset, DatasetType
-from literalai.evaluation.dataset_experiment import (
-    DatasetExperiment,
-    DatasetExperimentItem,
-)
-from literalai.observability.filter import (
-    generations_filters,
-    generations_order_by,
-    scores_filters,
-    scores_order_by,
-    steps_filters,
-    steps_order_by,
-    threads_filters,
-    threads_order_by,
-    users_filters,
-)
-from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
-
 from literalai.api.attachment_helpers import (
     AttachmentUpload,
     create_attachment_helper,
@@ -60,10 +41,12 @@
     get_generations_helper,
 )
 from literalai.api.prompt_helpers import (
+    PromptRollout,
     create_prompt_helper,
     create_prompt_lineage_helper,
+    get_prompt_ab_testing_helper,
     get_prompt_helper,
-    promote_prompt_helper,
+    update_prompt_ab_testing_helper,
 )
 from literalai.api.score_helpers import (
     ScoreUpdate,
@@ -98,29 +81,44 @@
     get_users_helper,
     update_user_helper,
 )
+from literalai.context import active_steps_var, active_thread_var
+from literalai.evaluation.dataset import Dataset, DatasetType
+from literalai.evaluation.dataset_experiment import (
+    DatasetExperiment,
+    DatasetExperimentItem,
+)
+from literalai.observability.filter import (
+    generations_filters,
+    generations_order_by,
+    scores_filters,
+    scores_order_by,
+    steps_filters,
+    steps_order_by,
+    threads_filters,
+    threads_order_by,
+    users_filters,
+)
+from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
 
 if TYPE_CHECKING:
     from typing import Tuple  # noqa: F401
 
 import httpx
 
-from literalai.my_types import (
-    Environment,
-    PaginatedResponse,
-)
+from literalai.my_types import Environment, PaginatedResponse
 from literalai.observability.generation import (
-    GenerationMessage,
-    CompletionGeneration,
     ChatGeneration,
+    CompletionGeneration,
+    GenerationMessage,
 )
 from literalai.observability.step import (
+    Attachment,
+    Score,
+    ScoreDict,
+    ScoreType,
     Step,
     StepDict,
     StepType,
-    ScoreType,
-    ScoreDict,
-    Score,
-    Attachment,
 )
 
 logger = logging.getLogger(__name__)
@@ -1365,21 +1363,33 @@ def get_prompt(
         else:
             raise ValueError("Either the `id` or the `name` must be provided.")
 
-    def promote_prompt(self, name: str, version: int) -> str:
+    def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
         """
-        Promotes the prompt with name to target version.
+        Get the A/B testing configuration for a prompt lineage.
 
         Args:
            name (str): The name of the prompt lineage.
-            version (int): The version number to promote.
-
         Returns:
-            str: The champion prompt ID.
+            List[PromptRollout]
        """
-        lineage = self.get_or_create_prompt_lineage(name)
-        lineage_id = lineage["id"]
+        return self.gql_helper(*get_prompt_ab_testing_helper(name=name))
 
-        return self.gql_helper(*promote_prompt_helper(lineage_id, version))
+    def update_prompt_ab_testing(
+        self, name: str, rollouts: List[PromptRollout]
+    ) -> Dict:
+        """
+        Update the A/B testing configuration for a prompt lineage.
+
+        Args:
+            name (str): The name of the prompt lineage.
+            rollouts (List[PromptRollout]): The percentage rollout for each prompt version.
+
+        Returns:
+            Dict
+        """
+        return self.gql_helper(
+            *update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
+        )
 
     # Misc API
 
@@ -2552,13 +2562,19 @@ async def get_prompt(
 
     get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__
 
-    async def promote_prompt(self, name: str, version: int) -> str:
-        lineage = await self.get_or_create_prompt_lineage(name)
-        lineage_id = lineage["id"]
+    async def update_prompt_ab_testing(
+        self, name: str, rollouts: List[PromptRollout]
+    ) -> Dict:
+        return await self.gql_helper(
+            *update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
+        )
+
+    update_prompt_ab_testing.__doc__ = LiteralAPI.update_prompt_ab_testing.__doc__
 
-        return await self.gql_helper(*promote_prompt_helper(lineage_id, version))
+    async def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
+        return await self.gql_helper(*get_prompt_ab_testing_helper(name=name))
 
-    promote_prompt.__doc__ = LiteralAPI.promote_prompt.__doc__
+    get_prompt_ab_testing.__doc__ = LiteralAPI.get_prompt_ab_testing.__doc__
 
     # Misc API
 
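In short, promote_prompt is replaced by an explicit rollout configuration per prompt lineage. A minimal usage sketch of the new methods (the API key is a placeholder and the prompt name is borrowed from the e2e test; neither is part of this diff):

    from literalai import LiteralClient

    client = LiteralClient(api_key="YOUR_API_KEY")  # placeholder key

    # Split traffic 60/40 between versions 0 and 1 of the lineage.
    client.api.update_prompt_ab_testing(
        name="Python SDK E2E Tests",
        rollouts=[{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}],
    )

    # Read the configuration back; each entry is a PromptRollout dict.
    for rollout in client.api.get_prompt_ab_testing(name="Python SDK E2E Tests"):
        print(rollout["version"], rollout["rollout"])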

literalai/api/gql.py

Lines changed: 27 additions & 8 deletions
@@ -1041,16 +1041,35 @@
 }
 """
 
-PROMOTE_PROMPT_VERSION = """mutation promotePromptVersion(
-    $lineageId: String!
-    $version: Int!
+GET_PROMPT_AB_TESTING = """query getPromptLineageRollout($projectId: String, $lineageName: String!) {
+    promptLineageRollout(projectId: $projectId, lineageName: $lineageName) {
+        pageInfo {
+            startCursor
+            endCursor
+        }
+        edges {
+            node {
+                version
+                rollout
+            }
+        }
+    }
+}
+"""
+
+UPDATE_PROMPT_AB_TESTING = """mutation updatePromptLineageRollout(
+    $projectId: String
+    $name: String!
+    $rollouts: [PromptVersionRolloutInput!]!
 ) {
-    promotePromptVersion(
-        lineageId: $lineageId
-        version: $version
+    updatePromptLineageRollout(
+        projectId: $projectId
+        name: $name
+        rollouts: $rollouts
     ) {
-        id
-        championId
+        ok
+        message
+        errorCode
     }
 }"""
 
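For reference, getPromptLineageRollout returns a Relay-style connection (pageInfo plus edges/node), and the helper in prompt_helpers.py below keeps only the node payloads. A hedged sketch of that shape, with invented values:

    # Illustrative response shape for GET_PROMPT_AB_TESTING (values are made up).
    response = {
        "data": {
            "promptLineageRollout": {
                "pageInfo": {"startCursor": "abc", "endCursor": "def"},
                "edges": [
                    {"node": {"version": 0, "rollout": 60}},
                    {"node": {"version": 1, "rollout": 40}},
                ],
            }
        }
    }

    # process_response in prompt_helpers.py reduces this to:
    # [{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}]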

literalai/api/prompt_helpers.py

Lines changed: 25 additions & 10 deletions
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict
 
 from literalai.observability.generation import GenerationMessage
 from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
@@ -61,16 +61,31 @@ def process_response(response):
     return gql.GET_PROMPT_VERSION, description, variables, process_response
 
 
-def promote_prompt_helper(
-    lineage_id: str,
-    version: int,
+class PromptRollout(TypedDict):
+    version: int
+    rollout: int
+
+
+def get_prompt_ab_testing_helper(
+    name: Optional[str] = None,
 ):
-    variables = {"lineageId": lineage_id, "version": version}
+    variables = {"lineageName": name}
+
+    def process_response(response) -> List[PromptRollout]:
+        response_data = response["data"]["promptLineageRollout"]
+        return list(map(lambda x: x["node"], response_data["edges"]))
+
+    description = "get prompt A/B testing"
+
+    return gql.GET_PROMPT_AB_TESTING, description, variables, process_response
+
+
+def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]):
+    variables = {"name": name, "rollouts": rollouts}
 
-    def process_response(response) -> str:
-        prompt = response["data"]["promotePromptVersion"]
-        return prompt["championId"] if prompt else None
+    def process_response(response) -> Dict:
+        return response["data"]["updatePromptLineageRollout"]
 
-    description = "promote prompt version"
+    description = "update prompt A/B testing"
 
-    return gql.PROMOTE_PROMPT_VERSION, description, variables, process_response
+    return gql.UPDATE_PROMPT_AB_TESTING, description, variables, process_response
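Each helper keeps the module's existing contract of returning a (query, description, variables, process_response) tuple that LiteralAPI.gql_helper unpacks. A standalone sketch of that flow under that assumption, with a caller-supplied transport standing in for the SDK's HTTP layer:

    from typing import Any, Callable, Dict, Tuple

    HelperOutput = Tuple[str, str, Dict[str, Any], Callable[[Dict[str, Any]], Any]]
    Transport = Callable[[str, Dict[str, Any]], Dict[str, Any]]

    def run_helper(helper_output: HelperOutput, transport: Transport) -> Any:
        # Unpack the helper tuple, send the GraphQL document with its variables,
        # then post-process the raw JSON response, mirroring what gql_helper is
        # expected to do.
        query, _description, variables, process_response = helper_output
        raw = transport(query, variables)
        return process_response(raw)

    # e.g. run_helper(update_prompt_ab_testing_helper("my-prompt", [{"version": 0, "rollout": 100}]), transport)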

literalai/prompt_engineering/prompt.py

Lines changed: 1 addition & 9 deletions
@@ -3,9 +3,8 @@
 from importlib.metadata import version
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
 
-from typing_extensions import deprecated, TypedDict
-
 import chevron
+from typing_extensions import TypedDict, deprecated
 
 if TYPE_CHECKING:
     from literalai.api import LiteralAPI
@@ -117,13 +116,6 @@ def from_dict(cls, api: "LiteralAPI", prompt_dict: PromptDict) -> "Prompt":
             variables_default_values=prompt_dict.get("variablesDefaultValues"),
         )
 
-    def promote(self) -> "Prompt":
-        """
-        Promotes this prompt to champion.
-        """
-        self.api.promote_prompt(self.name, self.version)
-        return self
-
     def format_messages(self, **kwargs: Any) -> List[Any]:
         """
         Formats the prompt's template messages with the given variables.

tests/e2e/test_e2e.py

Lines changed: 37 additions & 10 deletions
@@ -2,12 +2,13 @@
 import secrets
 import time
 import uuid
+from typing import List
 
 import pytest
 
 from literalai import AsyncLiteralClient, LiteralClient
 from literalai.context import active_steps_var
-from literalai.observability.generation import ChatGeneration
+from literalai.observability.generation import ChatGeneration, GenerationMessage
 from literalai.observability.thread import Thread
 
 """
@@ -384,7 +385,6 @@ def step_decorated():
     async def test_nested_run_steps(
         self, client: LiteralClient, async_client: AsyncLiteralClient
     ):
-
         @async_client.run(name="foo")
         def run_decorated():
             s = async_client.get_current_step()
@@ -627,16 +627,43 @@ async def test_prompt(self, async_client: AsyncLiteralClient):
         assert messages[0]["content"] == expected
 
     @pytest.mark.timeout(5)
-    async def test_champion_prompt(self, client: LiteralClient):
-        new_prompt = client.api.get_or_create_prompt(
-            name="Python SDK E2E Tests",
-            template_messages=[{"role": "user", "content": "Hello"}],
+    async def test_prompt_ab_testing(self, client: LiteralClient):
+        prompt_name = "Python SDK E2E Tests"
+
+        v0: List[GenerationMessage] = [{"role": "user", "content": "Hello"}]
+        v1: List[GenerationMessage] = [{"role": "user", "content": "Hello 2"}]
+
+        prompt_v0 = client.api.get_or_create_prompt(
+            name=prompt_name,
+            template_messages=v0,
         )
-        new_prompt.promote()
 
-        prompt = client.api.get_prompt(name="Python SDK E2E Tests")
-        assert prompt is not None
-        assert prompt.version == new_prompt.version
+        client.api.update_prompt_ab_testing(
+            prompt_v0.name, rollouts=[{"version": 0, "rollout": 100}]
+        )
+
+        ab_testing = client.api.get_prompt_ab_testing(name=prompt_v0.name)
+        assert len(ab_testing) == 1
+        assert ab_testing[0]["version"] == 0
+        assert ab_testing[0]["rollout"] == 100
+
+        prompt_v1 = client.api.get_or_create_prompt(
+            name=prompt_name,
+            template_messages=v1,
+        )
+
+        client.api.update_prompt_ab_testing(
+            name=prompt_v1.name,
+            rollouts=[{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}],
+        )
+
+        ab_testing = client.api.get_prompt_ab_testing(name=prompt_v1.name)
+
+        assert len(ab_testing) == 2
+        assert ab_testing[0]["version"] == 0
+        assert ab_testing[0]["rollout"] == 60
+        assert ab_testing[1]["version"] == 1
+        assert ab_testing[1]["rollout"] == 40
 
     @pytest.mark.timeout(5)
     async def test_gracefulness(self, broken_client: LiteralClient):
