
Commit a4499da

feat: Pass in metric version and name to experiments
1 parent c1c9dc9 commit a4499da

File tree

6 files changed: +278 -21 lines

src/galileo/experiments.py
src/galileo/schema/metrics.py
src/galileo/scorers.py
tests/schemas/test_metrics.py
tests/test_experiments.py
tests/test_scorers.py

src/galileo/experiments.py

Lines changed: 35 additions & 19 deletions
@@ -18,7 +18,7 @@
 )
 from galileo.resources.models import ExperimentResponse, HTTPValidationError, PromptRunSettings, ScorerConfig, TaskType
 from galileo.schema.datasets import DatasetRecord
-from galileo.schema.metrics import LocalMetricConfig
+from galileo.schema.metrics import LocalMetricConfig, Metric
 from galileo.scorers import Scorers, ScorerSettings
 from galileo.utils.datasets import load_dataset_and_records
 

@@ -88,27 +88,43 @@ def list(self, project_id: str) -> Optional[Union[HTTPValidationError, list["Exp
 
     @staticmethod
     def create_metric_configs(
-        project_id: str, experiment_id: str, metrics: builtins.list[Union[str, LocalMetricConfig]]
+        project_id: str, experiment_id: str, metrics: builtins.list[Union[str, LocalMetricConfig, Metric]]
     ) -> tuple[builtins.list[ScorerConfig], builtins.list[LocalMetricConfig]]:
         scorers = []
-        scorer_names = [metric for metric in metrics if isinstance(metric, str)]
-        if scorer_names:
-            all_scorers = Scorers().list()
-            known_metrics = {metric.name: metric for metric in all_scorers}
-            unknown_metrics = []
-            for metric in scorer_names:
-                if metric in known_metrics:
-                    scorers.append(ScorerConfig.from_dict(known_metrics[metric].to_dict()))
+
+        local_metric_configs: builtins.list[LocalMetricConfig] = []
+
+        all_scorers = Scorers().list()
+        known_metrics = {metric.name: metric for metric in all_scorers}
+
+        unknown_metrics = []
+
+        for metric in metrics:
+            if isinstance(metric, LocalMetricConfig):
+                local_metric_configs.append(metric)
+                continue
+            else:
+                name = metric.name if isinstance(metric, Metric) else metric
+                version = metric.version if isinstance(metric, Metric) else None
+
+                if name in known_metrics:
+                    raw_metric_dict = known_metrics[name].to_dict()
+
+                    # Set the version on the ScorerConfig if provided
+                    if version is not None:
+                        raw_version = Scorers().get_scorer_version(scorer_id=raw_metric_dict["id"], version=version)
+                        raw_metric_dict["scorer_version"] = raw_version.to_dict()
+                    scorers.append(ScorerConfig.from_dict(raw_metric_dict))
                 else:
-                    unknown_metrics.append(metric)
-            if unknown_metrics:
-                raise ValueError(
-                    "One or more non-existent metrics are specified:"
-                    + ", ".join(f"'{metric}'" for metric in unknown_metrics)
-                )
-        ScorerSettings().create(project_id=project_id, run_id=experiment_id, scorers=scorers)
+                    unknown_metrics.append(name)
+
+        if unknown_metrics:
+            raise ValueError(
+                "One or more non-existent metrics are specified:"
+                + ", ".join(f"'{metric}'" for metric in unknown_metrics)
+            )
 
-        local_metric_configs = [metric for metric in metrics if isinstance(metric, LocalMetricConfig)]
+        ScorerSettings().create(project_id=project_id, run_id=experiment_id, scorers=scorers)
 
         return scorers, local_metric_configs
 

@@ -212,7 +228,7 @@ def run_experiment(
     dataset: Optional[Union[Dataset, list[dict[str, str]], str]] = None,
     dataset_id: Optional[str] = None,
     dataset_name: Optional[str] = None,
-    metrics: Optional[list[Union[str, LocalMetricConfig]]] = None,
+    metrics: Optional[list[Union[str, LocalMetricConfig, Metric]]] = None,
     function: Optional[Callable] = None,
 ) -> Any:
     """

src/galileo/schema/metrics.py

Lines changed: 22 additions & 1 deletion
@@ -1,12 +1,13 @@
 from typing import Callable, Generic, Optional, TypeVar, Union
 
-from pydantic import BaseModel, Field, ValidationError, field_validator
+from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
 from pydantic_core.core_schema import ValidationInfo
 
 from galileo_core.schemas.logging.span import Span
 from galileo_core.schemas.logging.step import StepType
 from galileo_core.schemas.logging.trace import Trace
 from galileo_core.schemas.shared.metric import MetricValueType
+from galileo_core.schemas.shared.scorers.scorer_name import ScorerName
 
 MetricType = TypeVar("MetricType", bound=MetricValueType)
 

@@ -33,3 +34,23 @@ def set_aggregatable_types(cls, value: list[StepType], info: ValidationInfo) ->
             if step_type not in [StepType.workflow, StepType.trace]:
                 raise ValidationError("aggregatable_types can only contain trace or workflow steps")
         return value
+
+
+class Metric(BaseModel):
+    name: Union[str] = Field(
+        description="The name of the metric you want to run a specific version of (ie: 'Sentence Density')."
+    )
+    version: int | None = Field(
+        default=None,
+        description="The version of the metric (ie: 1, 2, 3, etc.). If None is provided, the 'default' version will be used.",
+    )
+
+    @model_validator(mode="after")
+    def validate_name_and_version(self) -> "Metric":
+        preset_metric_names = [scorer.value for scorer in ScorerName]
+        if self.name in preset_metric_names:
+            if self.version is not None:
+                raise ValueError(
+                    f"Galileo metric's '{self.name}' do not support versioning at this time. Please use the default version."
+                )
+        return self
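
In short, the validator only allows version to be set for custom metrics; preset Galileo metric names (the ScorerName values) must use the default version. A brief sketch of the intended behaviour, using an illustrative custom-metric name:

from galileo.schema.metrics import Metric

Metric(name="my_custom_metric")             # custom metric, default version
Metric(name="my_custom_metric", version=3)  # custom metric, pinned to version 3

# A preset name from ScorerName combined with an explicit version fails model
# validation and surfaces the "do not support versioning at this time" error above.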

src/galileo/scorers.py

Lines changed: 21 additions & 1 deletion
@@ -1,7 +1,11 @@
 from typing import Optional, Union
+from uuid import UUID
 
 from galileo.base import BaseClientModel
-from galileo.resources.api.data import list_scorers_with_filters_scorers_list_post
+from galileo.resources.api.data import (
+    get_scorer_version_or_latest_scorers_scorer_id_version_get,
+    list_scorers_with_filters_scorers_list_post,
+)
 from galileo.resources.api.run_scorer_settings import (
     upsert_scorers_config_projects_project_id_runs_run_id_scorer_settings_post,
 )

@@ -14,6 +18,7 @@
     ScorerTypeFilterOperator,
     ScorerTypes,
 )
+from galileo.resources.models.base_scorer_version_response import BaseScorerVersionResponse
 from galileo.resources.models.run_scorer_settings_patch_request import RunScorerSettingsPatchRequest
 from galileo.resources.models.run_scorer_settings_response import RunScorerSettingsResponse
 from galileo.resources.types import Unset

@@ -33,6 +38,21 @@ def list(self, types: list[ScorerTypes] = None) -> Union[Unset, list[ScorerRespo
         result = list_scorers_with_filters_scorers_list_post.sync(client=self.client, body=body)
         return result.scorers
 
+    def get_scorer_version(self, scorer_id: UUID, version: int) -> Union[Unset, BaseScorerVersionResponse]:
+        """
+        Args:
+            name: str
+                Name of the scorer
+            version: int
+                Version of the scorer.
+        Returns:
+            Scorer response if found, otherwise None
+        """
+        result = get_scorer_version_or_latest_scorers_scorer_id_version_get.sync(
+            scorer_id=scorer_id, version=version, client=self.client
+        )
+        return result
+
 
 class ScorerSettings(BaseClientModel):
     def create(
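
The new Scorers.get_scorer_version helper wraps the generated get_scorer_version_or_latest_scorers_scorer_id_version_get endpoint. A small sketch of calling it directly, with a placeholder scorer UUID (in practice the id comes from a ScorerResponse, as in experiments.py above):

from uuid import UUID

from galileo.scorers import Scorers

scorer_version = Scorers().get_scorer_version(
    scorer_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder id
    version=2,
)
# BaseScorerVersionResponse supports to_dict(), which experiments.py uses to
# attach the version to the ScorerConfig.
print(scorer_version.to_dict() if scorer_version else "version not found")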

tests/schemas/test_metrics.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import pytest
+from pydantic import ValidationError
+
+from galileo.schema.metrics import Metric
+from galileo_core.schemas.shared.scorers.scorer_name import ScorerName
+
+
+def test_metric_validator_preset_with_version():
+    """Test that creating a Metric with a preset name and version raises a ValidationError"""
+    # Get a valid value from the ScorerName enum
+    # First, get all the available enum values
+    preset_names = [scorer.value for scorer in ScorerName]
+    # Make sure there's at least one value
+    assert preset_names, "No values found in ScorerName enum"
+    preset_name = preset_names[0]
+
+    # Attempt to create a Metric with a preset name and a version
+    with pytest.raises(ValidationError) as exc_info:
+        Metric(name=preset_name, version=1)
+
+    # Verify the error message
+    assert f"Galileo metric's '{preset_name}' do not support versioning at this time" in str(exc_info.value)
+
+
+def test_metric_validator_preset_no_version():
+    """Test that creating a Metric with a preset name and no version is valid"""
+    # Get a valid value from the ScorerName enum
+    preset_names = [scorer.value for scorer in ScorerName]
+    assert preset_names, "No values found in ScorerName enum"
+    preset_name = preset_names[0]
+
+    # Create a Metric with a preset name and no version
+    metric = Metric(name=preset_name)
+
+    # Verify the metric is created correctly
+    assert metric.name == preset_name
+    assert metric.version is None
+
+
+def test_metric_validator_custom_with_version():
+    """Test that creating a Metric with a custom name and version is valid"""
+    # Create a Metric with a custom name and a version
+    metric = Metric(name="my_custom_metric", version=2)
+
+    # Verify the metric is created correctly
+    assert metric.name == "my_custom_metric"
+    assert metric.version == 2
+
+
+def test_metric_validator_custom_no_version():
+    """Test that creating a Metric with a custom name and no version is valid"""
+    # Create a Metric with a custom name and no version
+    metric = Metric(name="my_custom_metric")
+
+    # Verify the metric is created correctly
+    assert metric.name == "my_custom_metric"
+    assert metric.version is None

tests/test_experiments.py

Lines changed: 55 additions & 0 deletions
@@ -652,3 +652,58 @@ def test_create_scorer_configs(self, mock_scorer_settings, mock_scorers):
         # Test unknown metrics
         with pytest.raises(ValueError):
             Experiments.create_metric_configs("project_id", "experiment_id", ["unknown_metric"])
+
+    @patch("galileo.experiments.Scorers")
+    @patch("galileo.experiments.ScorerSettings")
+    def test_create_scorer_configs_with_metric_objects(self, mock_scorer_settings, mock_scorers):
+        # Setup mock return values
+        mock_scorers_instance = mock_scorers.return_value
+
+        # Create mock scorer responses
+        mock_scorers = [
+            ScorerResponse.from_dict({"id": "1", "name": "metric1", "scorer_type": "preset", "tags": ["test"]}),
+            ScorerResponse.from_dict({"id": "2", "name": "metric2", "scorer_type": "preset", "tags": ["test"]}),
+            ScorerResponse.from_dict({"id": "3", "name": "versionable_metric", "scorer_type": "llm", "tags": ["test"]}),
+        ]
+
+        mock_scorers_instance.list.return_value = mock_scorers
+
+        # Mock the get_scorer_version method
+        mock_version_response = MagicMock()
+        mock_version_response.to_dict.return_value = {"id": "version1", "version": 2}
+        mock_scorers_instance.get_scorer_version.return_value = mock_version_response
+
+        from galileo.schema.metrics import Metric
+
+        # Test with Metric objects (without version)
+        metric1 = Metric(name="metric1")
+        metric2 = Metric(name="metric2")
+
+        scorers, local_scorers = Experiments.create_metric_configs("project_id", "experiment_id", [metric1, metric2])
+
+        assert len(scorers) == 2  # Should return two valid scorers
+        assert len(local_scorers) == 0  # No local scorers
+
+        # Verify get_scorer_version was not called (since no version was specified)
+        mock_scorers_instance.get_scorer_version.assert_not_called()
+
+        # Test with a Metric object with version
+        versionable_metric = Metric(name="versionable_metric", version=2)
+
+        scorers, local_scorers = Experiments.create_metric_configs("project_id", "experiment_id", [versionable_metric])
+
+        assert len(scorers) == 1  # Should return one valid scorer
+        assert len(local_scorers) == 0  # No local scorers
+
+        # Verify get_scorer_version was called with the correct parameters
+        mock_scorers_instance.get_scorer_version.assert_called_once_with(scorer_id="3", version=2)
+
+        # Test mixed input types
+        local_metric = LocalMetricConfig(name="length", scorer_fn=lambda x: len(x))
+
+        scorers, local_scorers = Experiments.create_metric_configs(
+            "project_id", "experiment_id", ["metric1", local_metric, Metric(name="metric2")]
+        )
+
+        assert len(scorers) == 2  # Should return two valid scorers
+        assert len(local_scorers) == 1  # One local scorer

tests/test_scorers.py

Lines changed: 88 additions & 0 deletions
@@ -1,5 +1,8 @@
+import uuid
 from unittest.mock import ANY, Mock, patch
 
+from src.galileo.resources.models.base_scorer_version_response import BaseScorerVersionResponse
+
 from galileo.resources.models import (
     ListScorersRequest,
     ListScorersResponse,

@@ -67,3 +70,88 @@ def test_list_all_scorers_preset_filter(list_scorers_mock: Mock):
             filters=[ScorerTypeFilter(operator=ScorerTypeFilterOperator.EQ, value=ScorerTypes.LLM)]
         ),
     )
+
+
+def create_mock_version_response():
+    return BaseScorerVersionResponse.from_dict(
+        {
+            "id": "b8933a6d-7a65-4ce3-bfe4-b863109a0425",
+            "version": 2,
+            "model_name": "GPT-4o",
+            "num_judges": 3,
+            "created_at": "2025-03-28T18:54:02.848267+00:00",
+            "updated_at": "2025-03-28T18:54:02.848269+00:00",
+            "generated_scorer": {
+                "id": "c7933a6d-7a65-4ce3-bfe4-b863109a0499",
+                "name": "test_generated_scorer",
+                "instructions": "Evaluate the response quality",
+                "chain_poll_template": {
+                    "name": "quality_check",
+                    "prompt": "Rate the quality on a scale of 1-10",
+                    "template_type": "standard",
+                },
+            },
+            "registered_scorer": None,
+        }
+    )
+
+
+class MockHTTPError(Exception):
+    def __init__(self, status_code):
+        self.status_code = status_code
+        super().__init__(f"HTTP Error: {status_code}")
+
+
+@patch("galileo.scorers.get_scorer_version_or_latest_scorers_scorer_id_version_get")
+def test_get_scorer_version_success(get_scorer_version_mock: Mock):
+    # Setup
+    mock_response = create_mock_version_response()
+    get_scorer_version_mock.sync.return_value = mock_response
+    scorer_id = uuid.UUID("b8933a6d-7a65-4ce3-bfe4-b863109a0425")
+    version = 2
+
+    # Execute
+    result = Scorers().get_scorer_version(scorer_id=scorer_id, version=version)
+
+    # Verify
+    assert result == mock_response
+    get_scorer_version_mock.sync.assert_called_once_with(scorer_id=scorer_id, version=version, client=ANY)
+    assert result.id == "b8933a6d-7a65-4ce3-bfe4-b863109a0425"
+    assert result.version == 2
+    # Access properties from additional_properties instead
+    assert result.additional_properties["model_name"] == "GPT-4o"
+    assert result.additional_properties["num_judges"] == 3
+    # Access generated_scorer as a dictionary
+    assert result.generated_scorer["name"] == "test_generated_scorer"
+    assert result.registered_scorer is None
+
+
+@patch("galileo.scorers.get_scorer_version_or_latest_scorers_scorer_id_version_get")
+def test_get_scorer_version_not_found(get_scorer_version_mock: Mock):
+    # Setup
+    get_scorer_version_mock.sync.side_effect = MockHTTPError(status_code=404)
+    scorer_id = uuid.UUID("b8933a6d-7a65-4ce3-bfe4-b863109a0425")
+    version = 2
+
+    # Execute
+    result = Scorers().get_scorer_version(scorer_id=scorer_id, version=version)
+
+    # Verify
+    assert result is None
+    get_scorer_version_mock.sync.assert_called_once_with(scorer_id=scorer_id, version=version, client=ANY)
+
+
+@patch("galileo.scorers.get_scorer_version_or_latest_scorers_scorer_id_version_get")
+def test_get_scorer_version_other_error(get_scorer_version_mock: Mock):
+    # Setup
+    error = MockHTTPError(status_code=500)
+    get_scorer_version_mock.sync.side_effect = error
+    scorer_id = uuid.UUID("b8933a6d-7a65-4ce3-bfe4-b863109a0425")
+    version = 2
+
+    # Execute
+    result = Scorers().get_scorer_version(scorer_id=scorer_id, version=version)
+
+    # Verify - it seems the implementation catches all exceptions, not just 404s
+    assert result is None
+    get_scorer_version_mock.sync.assert_called_once_with(scorer_id=scorer_id, version=version, client=ANY)
