Skip to content

Commit 869001b

Browse files
bpkrothpre-commit-ci[bot]motus
authored
Refactor Status.parse (#982)
# Pull Request ## Title Small refactor to `Status.parse` ______________________________________________________________________ ## Description - Rename `Status.from_dict` -> `Status.parse` - Allow `Status.parse` to accept a `Status` as an argument and return it as is. (complete convenience for some python comprehension code elsewhere) - Add a test for that. ______________________________________________________________________ ## Type of Change - 🔄 Refactor ______________________________________________________________________ ## Testing Small new CI check and existing tests. ______________________________________________________________________ ## Additional Notes (optional) For convenience in another PR - #980 ______________________________________________________________________ --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sergiy Matusevych <sergiym@microsoft.com>
1 parent 3111f9b commit 869001b

File tree

4 files changed

+50
-20
lines changed

4 files changed

+50
-20
lines changed

mlos_bench/mlos_bench/environments/status.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,37 @@ class Status(enum.Enum):
2424
TIMED_OUT = 7
2525

2626
@staticmethod
27-
def from_str(status_str: Any) -> "Status":
28-
"""Convert a string to a Status enum."""
29-
if not isinstance(status_str, str):
30-
_LOG.warning("Expected type %s for status: %s", type(status_str), status_str)
31-
status_str = str(status_str)
32-
if status_str.isdigit():
27+
def parse(status: Any) -> "Status":
28+
"""
29+
Convert the input to a Status enum.
30+
31+
Parameters
32+
----------
33+
status : Any
34+
The status to parse. This can be a string (or string convertible),
35+
int, or Status enum.
36+
37+
Returns
38+
-------
39+
Status
40+
The corresponding Status enum value or else UNKNOWN if the input is not
41+
recognized.
42+
"""
43+
if isinstance(status, Status):
44+
return status
45+
if not isinstance(status, str):
46+
_LOG.warning("Expected type %s for status: %s", type(status), status)
47+
status = str(status)
48+
if status.isdigit():
3349
try:
34-
return Status(int(status_str))
50+
return Status(int(status))
3551
except ValueError:
36-
_LOG.warning("Unknown status: %d", int(status_str))
52+
_LOG.warning("Unknown status: %d", int(status))
3753
try:
38-
status_str = status_str.upper().strip()
39-
return Status[status_str]
54+
status = status.upper().strip()
55+
return Status[status]
4056
except KeyError:
41-
_LOG.warning("Unknown status: %s", status_str)
57+
_LOG.warning("Unknown status: %s", status)
4258
return Status.UNKNOWN
4359

4460
def is_good(self) -> bool:
@@ -113,4 +129,15 @@ def is_timed_out(self) -> bool:
113129
Status.TIMED_OUT,
114130
}
115131
)
116-
"""The set of completed statuses."""
132+
"""
133+
The set of completed statuses.
134+
135+
Includes all statuses that indicate the trial or experiment has finished, either
136+
successfully or not.
137+
This set is used to determine if a trial or experiment has reached a final state.
138+
This includes:
139+
- :py:attr:`.Status.SUCCEEDED`: The trial or experiment completed successfully.
140+
- :py:attr:`.Status.CANCELED`: The trial or experiment was canceled.
141+
- :py:attr:`.Status.FAILED`: The trial or experiment failed.
142+
- :py:attr:`.Status.TIMED_OUT`: The trial or experiment timed out.
143+
"""

mlos_bench/mlos_bench/storage/sql/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get_trials(
9595
config_id=trial.config_id,
9696
ts_start=utcify_timestamp(trial.ts_start, origin="utc"),
9797
ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"),
98-
status=Status.from_str(trial.status),
98+
status=Status.parse(trial.status),
9999
trial_runner_id=trial.trial_runner_id,
100100
)
101101
for trial in trials.fetchall()

mlos_bench/mlos_bench/storage/sql/experiment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def load(
188188
status: list[Status] = []
189189

190190
for trial in cur_trials.fetchall():
191-
stat = Status.from_str(trial.status)
191+
stat = Status.parse(trial.status)
192192
status.append(stat)
193193
trial_ids.append(trial.trial_id)
194194
configs.append(
@@ -272,7 +272,7 @@ def get_trial_by_id(
272272
config_id=trial.config_id,
273273
trial_runner_id=trial.trial_runner_id,
274274
opt_targets=self._opt_targets,
275-
status=Status.from_str(trial.status),
275+
status=Status.parse(trial.status),
276276
restoring=True,
277277
config=config,
278278
)
@@ -330,7 +330,7 @@ def pending_trials(
330330
config_id=trial.config_id,
331331
trial_runner_id=trial.trial_runner_id,
332332
opt_targets=self._opt_targets,
333-
status=Status.from_str(trial.status),
333+
status=Status.parse(trial.status),
334334
restoring=True,
335335
config=config,
336336
)

mlos_bench/mlos_bench/tests/environments/test_status.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,19 @@ def test_status_from_str_valid(input_str: str, expected_status: Status) -> None:
5151
Expected Status enum value.
5252
"""
5353
assert (
54-
Status.from_str(input_str) == expected_status
54+
Status.parse(input_str) == expected_status
5555
), f"Expected {expected_status} for input: {input_str}"
5656
# Check lowercase representation
5757
assert (
58-
Status.from_str(input_str.lower()) == expected_status
58+
Status.parse(input_str.lower()) == expected_status
5959
), f"Expected {expected_status} for input: {input_str.lower()}"
60+
assert (
61+
Status.parse(expected_status) == expected_status
62+
), f"Expected {expected_status} for input: {expected_status}"
6063
if input_str.isdigit():
6164
# Also test the numeric representation
6265
assert (
63-
Status.from_str(int(input_str)) == expected_status
66+
Status.parse(int(input_str)) == expected_status
6467
), f"Expected {expected_status} for input: {int(input_str)}"
6568

6669

@@ -83,7 +86,7 @@ def test_status_from_str_invalid(invalid_input: Any) -> None:
8386
input.
8487
"""
8588
assert (
86-
Status.from_str(invalid_input) == Status.UNKNOWN
89+
Status.parse(invalid_input) == Status.UNKNOWN
8790
), f"Expected Status.UNKNOWN for invalid input: {invalid_input}"
8891

8992

0 commit comments

Comments
 (0)