Skip to content

Commit 0c17ea3

Browse files
committed
fix parameter history
1 parent 784c079 commit 0c17ea3

File tree

1 file changed

+55
-81
lines changed

1 file changed

+55
-81
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 55 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np # type: ignore
33
import iklayout # type: ignore
44
import matplotlib.pyplot as plt # type: ignore
5-
import plotly.graph_objects as go # type: ignore
5+
import plotly.graph_objects as go # type: ignore
66
from typing import List, Optional, Tuple, Dict, Set
77

88
from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation
@@ -56,9 +56,7 @@ def plot_constraints(
5656
labels: List of labels for each constraint value.
5757
"""
5858

59-
constraints_labels = constraints_labels or [
60-
f"Constraint {i}" for i in range(len(constraints[0]))
61-
]
59+
constraints_labels = constraints_labels or [f"Constraint {i}" for i in range(len(constraints[0]))]
6260
iterations = iterations or list(range(len(constraints[0])))
6361

6462
plt.clf()
@@ -92,13 +90,9 @@ def plot_single_spectrum(
9290
plt.ylabel("Losses")
9391
plt.plot(wavelengths, spectrum)
9492
for x_val in vlines:
95-
plt.axvline(
96-
x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})"
97-
) # Add vertical line
93+
plt.axvline(x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})") # Add vertical line
9894
for y_val in hlines:
99-
plt.axhline(
100-
y=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})"
101-
) # Add vertical line
95+
plt.axhline(y=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})") # Add vertical line
10296
return plt.gcf()
10397

10498

@@ -109,7 +103,7 @@ def plot_interactive_spectra(
109103
vlines: Optional[List[float]] = None,
110104
hlines: Optional[List[float]] = None,
111105
):
112-
""""
106+
""" "
113107
Creates an interactive plot of spectra with a slider to select different indices.
114108
Parameters:
115109
-----------
@@ -131,7 +125,7 @@ def plot_interactive_spectra(
131125
vlines = []
132126
if hlines is None:
133127
hlines = []
134-
128+
135129
# Adjust y-axis range
136130
all_vals = [val for spec in spectra for iteration in spec for val in iteration]
137131
y_min = min(all_vals)
@@ -143,49 +137,28 @@ def plot_interactive_spectra(
143137
# Create hlines and vlines
144138
shapes = []
145139
for xv in vlines:
146-
shapes.append(dict(
147-
type="line",
148-
xref="x", x0=xv, x1=xv,
149-
yref="paper", y0=0, y1=1,
150-
line=dict(color="red", dash="dash")
151-
))
140+
shapes.append(
141+
dict(type="line", xref="x", x0=xv, x1=xv, yref="paper", y0=0, y1=1, line=dict(color="red", dash="dash"))
142+
)
152143
for yh in hlines:
153-
shapes.append(dict(
154-
type="line",
155-
xref="paper", x0=0, x1=1,
156-
yref="y", y0=yh, y1=yh,
157-
line=dict(color="red", dash="dash")
158-
))
159-
160-
144+
shapes.append(
145+
dict(type="line", xref="paper", x0=0, x1=1, yref="y", y0=yh, y1=yh, line=dict(color="red", dash="dash"))
146+
)
147+
161148
# Create frames for each index
162149
slider_index = list(range(len(spectra[0])))
163150
fig = go.Figure()
164151

165152
# Build initial figure for immediate display
166153
init_idx = slider_index[0]
167154
for i, spec in enumerate(spectra):
168-
fig.add_trace(
169-
go.Scatter(
170-
x=wavelengths,
171-
y=spec[init_idx],
172-
mode="lines",
173-
name=spectrum_labels[i]
174-
)
175-
)
155+
fig.add_trace(go.Scatter(x=wavelengths, y=spec[init_idx], mode="lines", name=spectrum_labels[i]))
176156
# Build frames for animation
177157
frames = []
178158
for idx in slider_index:
179159
frame_data = []
180160
for i, spec in enumerate(spectra):
181-
frame_data.append(
182-
go.Scatter(
183-
x=wavelengths,
184-
y=spec[idx],
185-
mode="lines",
186-
name=spectrum_labels[i]
187-
)
188-
)
161+
frame_data.append(go.Scatter(x=wavelengths, y=spec[idx], mode="lines", name=spectrum_labels[i]))
189162
frames.append(
190163
go.Frame(
191164
data=frame_data,
@@ -195,30 +168,22 @@ def plot_interactive_spectra(
195168

196169
fig.frames = frames
197170

198-
199171
# Create transition steps
200172
steps = []
201173
for idx in slider_index:
202-
steps.append(dict(
203-
method="animate",
204-
args=[
205-
[str(idx)],
206-
{
207-
"mode": "immediate",
208-
"frame": {"duration": 0, "redraw": True},
209-
"transition": {"duration": 0}
210-
}
211-
],
212-
label=str(idx),
213-
))
174+
steps.append(
175+
dict(
176+
method="animate",
177+
args=[
178+
[str(idx)],
179+
{"mode": "immediate", "frame": {"duration": 0, "redraw": True}, "transition": {"duration": 0}},
180+
],
181+
label=str(idx),
182+
)
183+
)
214184

215185
# Create the slider
216-
sliders = [dict(
217-
active=0,
218-
currentvalue={"prefix": "Index: "},
219-
pad={"t": 50},
220-
steps=steps
221-
)]
186+
sliders = [dict(active=0, currentvalue={"prefix": "Index: "}, pad={"t": 50}, steps=steps)]
222187

223188
# Create the layout
224189
fig.update_layout(
@@ -253,25 +218,32 @@ def plot_parameter_history(parameters: List[Parameter], parameter_history: List[
253218
plt.xlabel("Iterations")
254219
plt.ylabel(param.path)
255220
split_param = param.path.split(",")
256-
plt.plot(
257-
[
258-
parameter_history[i][split_param[0]][split_param[1]]
259-
for i in range(len(parameter_history))
260-
]
261-
)
221+
if "," in param.path:
222+
split_param = param.path.split(",")
223+
plt.plot([parameter_history[i][split_param[0]][split_param[1]] for i in range(len(parameter_history))])
224+
else:
225+
plt.plot([parameter_history[i][param.path] for i in range(len(parameter_history))])
262226
plt.show()
263227

264228

265-
def print_statements(statements: StatementDictionary, validation: Optional[StatementValidationDictionary] = None, only_formalized: bool = False):
229+
def print_statements(
230+
statements: StatementDictionary,
231+
validation: Optional[StatementValidationDictionary] = None,
232+
only_formalized: bool = False,
233+
):
266234
"""
267235
Print a list of statements in nice readable format.
268236
"""
269237

270238
validation = StatementValidationDictionary(
271-
cost_functions=(validation.cost_functions if validation is not None else None) or [StatementValidation()]*len(statements.cost_functions or []),
272-
parameter_constraints=(validation.parameter_constraints if validation is not None else None) or [StatementValidation()]*len(statements.parameter_constraints or []),
273-
structure_constraints=(validation.structure_constraints if validation is not None else None) or [StatementValidation()]*len(statements.structure_constraints or []),
274-
unformalizable_statements=(validation.unformalizable_statements if validation is not None else None) or [StatementValidation()]*len(statements.unformalizable_statements or [])
239+
cost_functions=(validation.cost_functions if validation is not None else None)
240+
or [StatementValidation()] * len(statements.cost_functions or []),
241+
parameter_constraints=(validation.parameter_constraints if validation is not None else None)
242+
or [StatementValidation()] * len(statements.parameter_constraints or []),
243+
structure_constraints=(validation.structure_constraints if validation is not None else None)
244+
or [StatementValidation()] * len(statements.structure_constraints or []),
245+
unformalizable_statements=(validation.unformalizable_statements if validation is not None else None)
246+
or [StatementValidation()] * len(statements.unformalizable_statements or []),
275247
)
276248

277249
if len(validation.cost_functions or []) != len(statements.cost_functions or []):
@@ -299,8 +271,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
299271
if computation is not None:
300272
args_str = ", ".join(
301273
[
302-
f"{argname}="
303-
+ (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
274+
f"{argname}=" + (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
304275
for argname, argvalue in computation.arguments.items()
305276
]
306277
)
@@ -326,8 +297,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
326297
if computation is not None:
327298
args_str = ", ".join(
328299
[
329-
f"{argname}="
330-
+ (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
300+
f"{argname}=" + (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
331301
for argname, argvalue in computation.arguments.items()
332302
]
333303
)
@@ -382,9 +352,7 @@ def _str_units_to_float(str_units: str) -> float:
382352
return float(numeric_value * unit_conversions[unit])
383353

384354

385-
def get_wavelengths_to_plot(
386-
statements: StatementDictionary, num_samples: int = 100
387-
) -> Tuple[List[float], List[float]]:
355+
def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int = 100) -> Tuple[List[float], List[float]]:
388356
"""
389357
Get the wavelengths to plot based on the statements.
390358
@@ -401,10 +369,16 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
401369
continue
402370
if "wavelengths" in comp.arguments:
403371
vlines = vlines | {
404-
_str_units_to_float(wl) for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else []) if isinstance(wl, str)
372+
_str_units_to_float(wl)
373+
for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else [])
374+
if isinstance(wl, str)
405375
}
406376
if "wavelength_range" in comp.arguments:
407-
if isinstance(comp.arguments["wavelength_range"], list) and len(comp.arguments["wavelength_range"]) == 2 and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"]):
377+
if (
378+
isinstance(comp.arguments["wavelength_range"], list)
379+
and len(comp.arguments["wavelength_range"]) == 2
380+
and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"])
381+
):
408382
min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0]))
409383
max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1]))
410384
return min_wl, max_wl, vlines

0 commit comments

Comments
 (0)