Skip to content
This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit d2cae60

Browse files
feat: prompt variant (#142)
* feat: prompt variant * fix: mypy errors on mistralai * fix: rename + version bump --------- Co-authored-by: Hugues de Saxcé <hugues.de.saxce@protonmail.com>
1 parent ef9b05d commit d2cae60

File tree

10 files changed

+183
-28
lines changed

10 files changed

+183
-28
lines changed

literalai/api/__init__.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@
4444
PromptRollout,
4545
create_prompt_helper,
4646
create_prompt_lineage_helper,
47+
create_prompt_variant_helper,
4748
get_prompt_ab_testing_helper,
4849
get_prompt_helper,
50+
get_prompt_lineage_helper,
4951
update_prompt_ab_testing_helper,
5052
)
5153
from literalai.api.score_helpers import (
@@ -144,7 +146,6 @@ def handle_bytes(item):
144146

145147

146148
class BaseLiteralAPI:
147-
148149
def __init__(
149150
self,
150151
api_key: Optional[str] = None,
@@ -201,7 +202,6 @@ class LiteralAPI(BaseLiteralAPI):
201202
def make_gql_call(
202203
self, description: str, query: str, variables: Dict[str, Any]
203204
) -> Dict:
204-
205205
def raise_error(error):
206206
logger.error(f"Failed to {description}: {error}")
207207
raise Exception(error)
@@ -1141,7 +1141,7 @@ def create_experiment(
11411141
self,
11421142
name: str,
11431143
dataset_id: Optional[str] = None,
1144-
prompt_id: Optional[str] = None,
1144+
prompt_variant_id: Optional[str] = None,
11451145
params: Optional[Dict] = None,
11461146
) -> "DatasetExperiment":
11471147
"""
@@ -1150,7 +1150,7 @@ def create_experiment(
11501150
Args:
11511151
name (str): The name of the experiment.
11521152
dataset_id (Optional[str]): The unique identifier of the dataset.
1153-
prompt_id (Optional[str]): The identifier of the prompt associated with the experiment.
1153+
prompt_variant_id (Optional[str]): The identifier of the prompt variant to associate with the experiment.
11541154
params (Optional[Dict]): Additional parameters for the experiment.
11551155
11561156
Returns:
@@ -1161,7 +1161,7 @@ def create_experiment(
11611161
api=self,
11621162
name=name,
11631163
dataset_id=dataset_id,
1164-
prompt_id=prompt_id,
1164+
prompt_variant_id=prompt_variant_id,
11651165
params=params,
11661166
)
11671167
)
@@ -1369,6 +1369,34 @@ def get_prompt(
13691369
else:
13701370
raise ValueError("Either the `id` or the `name` must be provided.")
13711371

1372+
def create_prompt_variant(
    self,
    name: str,
    template_messages: List[GenerationMessage],
    settings: Optional[ProviderSettings] = None,
    tools: Optional[List[Dict]] = None,
) -> Optional[str]:
    """
    Creates a prompt variant to run an experiment against.
    The variant is not an official prompt version until it is manually saved.

    Args:
        name (str): The prompt name used to look up the lineage the variant derives from.
        template_messages (List[GenerationMessage]): The template messages of the variant.
        settings (Optional[ProviderSettings]): Optional provider settings for the variant.
        tools (Optional[List[Dict]]): Optional tool options for the model.

    Returns:
        Optional[str]: The prompt variant id to pass as `prompt_variant_id`
        when creating an experiment.
    """
    # First resolve the lineage by name; the lookup result may be empty.
    lineage_call = get_prompt_lineage_helper(name)
    found_lineage = self.gql_helper(*lineage_call)

    # No lineage is not an error: the variant is then created with a
    # null source lineage id.
    if found_lineage:
        lineage_id = found_lineage["id"]
    else:
        lineage_id = None

    variant_call = create_prompt_variant_helper(
        lineage_id, template_messages, settings, tools
    )
    return self.gql_helper(*variant_call)
1399+
13721400
def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
13731401
"""
13741402
Get the A/B testing configuration for a prompt lineage.
@@ -2351,7 +2379,7 @@ async def create_experiment(
23512379
self,
23522380
name: str,
23532381
dataset_id: Optional[str] = None,
2354-
prompt_id: Optional[str] = None,
2382+
prompt_variant_id: Optional[str] = None,
23552383
params: Optional[Dict] = None,
23562384
) -> "DatasetExperiment":
23572385
sync_api = LiteralAPI(self.api_key, self.url)
@@ -2361,7 +2389,7 @@ async def create_experiment(
23612389
api=sync_api,
23622390
name=name,
23632391
dataset_id=dataset_id,
2364-
prompt_id=prompt_id,
2392+
prompt_variant_id=prompt_variant_id,
23652393
params=params,
23662394
)
23672395
)
@@ -2529,6 +2557,36 @@ async def create_prompt(
25292557
):
25302558
return await self.get_or_create_prompt(name, template_messages, settings)
25312559

2560+
async def create_prompt_variant(
    self,
    name: str,
    template_messages: List[GenerationMessage],
    settings: Optional[ProviderSettings] = None,
    tools: Optional[List[Dict]] = None,
) -> Optional[str]:
    """
    Asynchronously creates a prompt variant to run an experiment against.
    The variant is not an official prompt version until it is manually saved.

    Args:
        name (str): The prompt name used to look up the lineage the variant derives from.
        template_messages (List[GenerationMessage]): The template messages of the variant.
        settings (Optional[ProviderSettings]): Optional provider settings for the variant.
        tools (Optional[List[Dict]]): Optional tool options for the model.

    Returns:
        Optional[str]: The prompt variant id to pass as `prompt_variant_id`
        when creating an experiment.
    """
    # First resolve the lineage by name; the lookup result may be empty.
    lineage_call = get_prompt_lineage_helper(name)
    found_lineage = await self.gql_helper(*lineage_call)

    # No lineage is not an error: the variant is then created with a
    # null source lineage id.
    if found_lineage:
        lineage_id = found_lineage["id"]
    else:
        lineage_id = None

    variant_call = create_prompt_variant_helper(
        lineage_id, template_messages, settings, tools
    )
    return await self.gql_helper(*variant_call)

# The published docstring is the synchronous client's one, so both clients
# expose identical documentation.
create_prompt_variant.__doc__ = LiteralAPI.create_prompt_variant.__doc__
2589+
25322590
async def get_prompt(
25332591
self,
25342592
id: Optional[str] = None,

literalai/api/dataset_helpers.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from typing import TYPE_CHECKING, Dict, Optional
22

33
from literalai.api import gql
4-
54
from literalai.evaluation.dataset import Dataset, DatasetType
6-
from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem
5+
from literalai.evaluation.dataset_experiment import (
6+
DatasetExperiment,
7+
DatasetExperimentItem,
8+
)
79
from literalai.evaluation.dataset_item import DatasetItem
810

911
if TYPE_CHECKING:
1012
from literalai.api import LiteralAPI
1113

1214

13-
1415
def create_dataset_helper(
1516
api: "LiteralAPI",
1617
name: str,
@@ -98,13 +99,13 @@ def create_experiment_helper(
9899
api: "LiteralAPI",
99100
name: str,
100101
dataset_id: Optional[str] = None,
101-
prompt_id: Optional[str] = None,
102+
prompt_variant_id: Optional[str] = None,
102103
params: Optional[Dict] = None,
103104
):
104105
variables = {
105106
"datasetId": dataset_id,
106107
"name": name,
107-
"promptId": prompt_id,
108+
"promptExperimentId": prompt_variant_id,
108109
"params": params,
109110
}
110111

literalai/api/gql.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -833,18 +833,19 @@
833833
mutation CreateDatasetExperiment(
834834
$name: String!
835835
$datasetId: String
836-
$promptId: String
836+
$promptExperimentId: String
837837
$params: Json
838838
) {
839839
createDatasetExperiment(
840840
name: $name
841841
datasetId: $datasetId
842-
promptId: $promptId
842+
promptExperimentId: $promptExperimentId
843843
params: $params
844844
) {
845845
id
846846
name
847847
datasetId
848+
promptExperimentId
848849
params
849850
}
850851
}
@@ -991,6 +992,16 @@
991992
}
992993
}"""
993994

995+
GET_PROMPT_LINEAGE = """query promptLineage(
996+
$name: String!
997+
) {
998+
promptLineage(
999+
name: $name
1000+
) {
1001+
id
1002+
}
1003+
}"""
1004+
9941005
CREATE_PROMPT_VERSION = """mutation createPromptVersion(
9951006
$lineageId: String!
9961007
$versionDesc: String
@@ -1021,6 +1032,38 @@
10211032
}
10221033
}"""
10231034

1035+
CREATE_PROMPT_VARIANT = """mutation createPromptExperiment(
1036+
$fromLineageId: String
1037+
$fromVersion: Int
1038+
$scoreTemplateId: String
1039+
$templateMessages: Json
1040+
$settings: Json
1041+
$tools: Json
1042+
$variables: Json
1043+
) {
1044+
createPromptExperiment(
1045+
fromLineageId: $fromLineageId
1046+
fromVersion: $fromVersion
1047+
scoreTemplateId: $scoreTemplateId
1048+
templateMessages: $templateMessages
1049+
settings: $settings
1050+
tools: $tools
1051+
variables: $variables
1052+
) {
1053+
id
1054+
fromLineageId
1055+
fromVersion
1056+
scoreTemplateId
1057+
projectId
1058+
projectUserId
1059+
tools
1060+
settings
1061+
variables
1062+
templateMessages
1063+
}
1064+
}
1065+
"""
1066+
10241067
GET_PROMPT_VERSION = """
10251068
query GetPrompt($id: String, $name: String, $version: Int) {
10261069
promptVersion(id: $id, name: $name, version: $version) {

literalai/api/prompt_helpers.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ def process_response(response):
2121
return gql.CREATE_PROMPT_LINEAGE, description, variables, process_response
2222

2323

24+
def get_prompt_lineage_helper(name: str):
    """
    Build the pieces of the GraphQL call that fetches a prompt lineage by name.

    Args:
        name (str): The name sent as the `name` query variable.

    Returns:
        A 4-tuple `(query, description, variables, process_response)`, where
        `process_response` extracts the `promptLineage` entry from the raw
        response's `data` payload.
    """

    def process_response(response):
        # Hand back the lineage payload untouched; callers decide how to
        # treat an empty result.
        return response["data"]["promptLineage"]

    return (
        gql.GET_PROMPT_LINEAGE,
        "get prompt lineage",
        {"name": name},
        process_response,
    )
34+
35+
2436
def create_prompt_helper(
2537
api: "LiteralAPI",
2638
lineage_id: str,
@@ -61,6 +73,28 @@ def process_response(response):
6173
return gql.GET_PROMPT_VERSION, description, variables, process_response
6274

6375

76+
def create_prompt_variant_helper(
    from_lineage_id: Optional[str] = None,
    template_messages: Optional[List[GenerationMessage]] = None,
    settings: Optional[ProviderSettings] = None,
    tools: Optional[List[Dict]] = None,
):
    """
    Build the pieces of the GraphQL call that creates a prompt variant.

    Args:
        from_lineage_id (Optional[str]): Id of the prompt lineage the variant
            derives from, if any.
        template_messages (Optional[List[GenerationMessage]]): Template
            messages of the variant. When omitted or None, an empty list is sent.
        settings (Optional[ProviderSettings]): Optional provider settings.
        tools (Optional[List[Dict]]): Optional tool options for the model.

    Returns:
        A 4-tuple `(query, description, variables, process_response)`, where
        `process_response` returns the created variant's `id`, or None when
        the response carries no variant.
    """
    # Use a fresh list per call. A `[]` default is a single object shared by
    # every call and exposed to callers through `variables`, so one mutation
    # would leak into all later calls.
    if template_messages is None:
        template_messages = []

    variables = {
        "fromLineageId": from_lineage_id,
        "templateMessages": template_messages,
        "settings": settings,
        "tools": tools,
    }

    def process_response(response):
        variant = response["data"]["createPromptExperiment"]
        return variant["id"] if variant else None

    description = "create prompt variant"

    return gql.CREATE_PROMPT_VARIANT, description, variables, process_response
96+
97+
6498
class PromptRollout(TypedDict):
6599
version: int
66100
rollout: int

literalai/evaluation/dataset.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from dataclasses import dataclass, field
22
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, cast
33

4-
from literalai.my_types import Utils
5-
64
from typing_extensions import TypedDict
75

6+
from literalai.my_types import Utils
7+
88
if TYPE_CHECKING:
99
from literalai.api import LiteralAPI
1010

@@ -101,17 +101,23 @@ def create_item(
101101
return dataset_item
102102

103103
def create_experiment(
104-
self, name: str, prompt_id: Optional[str] = None, params: Optional[Dict] = None
104+
self,
105+
name: str,
106+
prompt_variant_id: Optional[str] = None,
107+
params: Optional[Dict] = None,
105108
) -> DatasetExperiment:
106109
"""
107110
Creates a new dataset experiment based on this dataset.
108111
:param name: The name of the experiment.
109-
:param prompt_id: The Prompt ID used on LLM calls (optional).
112+
:param prompt_variant_id: The Prompt variant ID to experiment on.
110113
:param params: The params used on the experiment.
111114
:return: The created DatasetExperiment instance.
112115
"""
113116
experiment = self.api.create_experiment(
114-
name=name, dataset_id=self.id, prompt_id=prompt_id, params=params
117+
name=name,
118+
dataset_id=self.id,
119+
prompt_variant_id=prompt_variant_id,
120+
params=params,
115121
)
116122
return experiment
117123

literalai/evaluation/dataset_experiment.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class DatasetExperimentDict(TypedDict, total=False):
5959
name: str
6060
datasetId: str
6161
params: Dict
62-
promptId: Optional[str]
62+
promptExperimentId: Optional[str]
6363
items: Optional[List[DatasetExperimentItemDict]]
6464

6565

@@ -71,7 +71,7 @@ class DatasetExperiment(Utils):
7171
name: str
7272
dataset_id: Optional[str]
7373
params: Optional[Dict]
74-
prompt_id: Optional[str] = None
74+
prompt_variant_id: Optional[str] = None
7575
items: List[DatasetExperimentItem] = field(default_factory=lambda: [])
7676

7777
def log(self, item_dict: DatasetExperimentItemDict) -> DatasetExperimentItem:
@@ -97,7 +97,7 @@ def to_dict(self):
9797
"createdAt": self.created_at,
9898
"name": self.name,
9999
"datasetId": self.dataset_id,
100-
"promptId": self.prompt_id,
100+
"promptExperimentId": self.prompt_variant_id,
101101
"params": self.params,
102102
"items": [item.to_dict() for item in self.items],
103103
}
@@ -116,6 +116,6 @@ def from_dict(
116116
name=dataset_experiment.get("name", ""),
117117
dataset_id=dataset_experiment.get("datasetId", ""),
118118
params=dataset_experiment.get("params"),
119-
prompt_id=dataset_experiment.get("promptId"),
119+
prompt_variant_id=dataset_experiment.get("promptExperimentId"),
120120
items=[DatasetExperimentItem.from_dict(item) for item in items],
121121
)

0 commit comments

Comments
 (0)