2
2
import numpy as np # type: ignore
3
3
import iklayout # type: ignore
4
4
import matplotlib .pyplot as plt # type: ignore
5
- import plotly .graph_objects as go # type: ignore
5
+ import plotly .graph_objects as go # type: ignore
6
6
from typing import List , Optional , Tuple , Dict , Set
7
7
8
8
from . import Parameter , StatementDictionary , StatementValidationDictionary , StatementValidation , Computation
@@ -56,9 +56,7 @@ def plot_constraints(
56
56
labels: List of labels for each constraint value.
57
57
"""
58
58
59
- constraints_labels = constraints_labels or [
60
- f"Constraint { i } " for i in range (len (constraints [0 ]))
61
- ]
59
+ constraints_labels = constraints_labels or [f"Constraint { i } " for i in range (len (constraints [0 ]))]
62
60
iterations = iterations or list (range (len (constraints [0 ])))
63
61
64
62
plt .clf ()
@@ -92,13 +90,9 @@ def plot_single_spectrum(
92
90
plt .ylabel ("Losses" )
93
91
plt .plot (wavelengths , spectrum )
94
92
for x_val in vlines :
95
- plt .axvline (
96
- x = x_val , color = "red" , linestyle = "--" , label = f"Wavelength (x={ x_val } )"
97
- ) # Add vertical line
93
+ plt .axvline (x = x_val , color = "red" , linestyle = "--" , label = f"Wavelength (x={ x_val } )" ) # Add vertical line
98
94
for y_val in hlines :
99
- plt .axhline (
100
- y = y_val , color = "red" , linestyle = "--" , label = f"Transmission (y={ y_val } )"
101
- ) # Add vertical line
95
+ plt .axhline (y = y_val , color = "red" , linestyle = "--" , label = f"Transmission (y={ y_val } )" ) # Add vertical line
102
96
return plt .gcf ()
103
97
104
98
@@ -109,7 +103,7 @@ def plot_interactive_spectra(
109
103
vlines : Optional [List [float ]] = None ,
110
104
hlines : Optional [List [float ]] = None ,
111
105
):
112
- """"
106
+ """ "
113
107
Creates an interactive plot of spectra with a slider to select different indices.
114
108
Parameters:
115
109
-----------
@@ -131,7 +125,7 @@ def plot_interactive_spectra(
131
125
vlines = []
132
126
if hlines is None :
133
127
hlines = []
134
-
128
+
135
129
# Adjust y-axis range
136
130
all_vals = [val for spec in spectra for iteration in spec for val in iteration ]
137
131
y_min = min (all_vals )
@@ -143,49 +137,28 @@ def plot_interactive_spectra(
143
137
# Create hlines and vlines
144
138
shapes = []
145
139
for xv in vlines :
146
- shapes .append (dict (
147
- type = "line" ,
148
- xref = "x" , x0 = xv , x1 = xv ,
149
- yref = "paper" , y0 = 0 , y1 = 1 ,
150
- line = dict (color = "red" , dash = "dash" )
151
- ))
140
+ shapes .append (
141
+ dict (type = "line" , xref = "x" , x0 = xv , x1 = xv , yref = "paper" , y0 = 0 , y1 = 1 , line = dict (color = "red" , dash = "dash" ))
142
+ )
152
143
for yh in hlines :
153
- shapes .append (dict (
154
- type = "line" ,
155
- xref = "paper" , x0 = 0 , x1 = 1 ,
156
- yref = "y" , y0 = yh , y1 = yh ,
157
- line = dict (color = "red" , dash = "dash" )
158
- ))
159
-
160
-
144
+ shapes .append (
145
+ dict (type = "line" , xref = "paper" , x0 = 0 , x1 = 1 , yref = "y" , y0 = yh , y1 = yh , line = dict (color = "red" , dash = "dash" ))
146
+ )
147
+
161
148
# Create frames for each index
162
149
slider_index = list (range (len (spectra [0 ])))
163
150
fig = go .Figure ()
164
151
165
152
# Build initial figure for immediate display
166
153
init_idx = slider_index [0 ]
167
154
for i , spec in enumerate (spectra ):
168
- fig .add_trace (
169
- go .Scatter (
170
- x = wavelengths ,
171
- y = spec [init_idx ],
172
- mode = "lines" ,
173
- name = spectrum_labels [i ]
174
- )
175
- )
155
+ fig .add_trace (go .Scatter (x = wavelengths , y = spec [init_idx ], mode = "lines" , name = spectrum_labels [i ]))
176
156
# Build frames for animation
177
157
frames = []
178
158
for idx in slider_index :
179
159
frame_data = []
180
160
for i , spec in enumerate (spectra ):
181
- frame_data .append (
182
- go .Scatter (
183
- x = wavelengths ,
184
- y = spec [idx ],
185
- mode = "lines" ,
186
- name = spectrum_labels [i ]
187
- )
188
- )
161
+ frame_data .append (go .Scatter (x = wavelengths , y = spec [idx ], mode = "lines" , name = spectrum_labels [i ]))
189
162
frames .append (
190
163
go .Frame (
191
164
data = frame_data ,
@@ -195,30 +168,22 @@ def plot_interactive_spectra(
195
168
196
169
fig .frames = frames
197
170
198
-
199
171
# Create transition steps
200
172
steps = []
201
173
for idx in slider_index :
202
- steps .append (dict (
203
- method = "animate" ,
204
- args = [
205
- [str (idx )],
206
- {
207
- "mode" : "immediate" ,
208
- "frame" : {"duration" : 0 , "redraw" : True },
209
- "transition" : {"duration" : 0 }
210
- }
211
- ],
212
- label = str (idx ),
213
- ))
174
+ steps .append (
175
+ dict (
176
+ method = "animate" ,
177
+ args = [
178
+ [str (idx )],
179
+ {"mode" : "immediate" , "frame" : {"duration" : 0 , "redraw" : True }, "transition" : {"duration" : 0 }},
180
+ ],
181
+ label = str (idx ),
182
+ )
183
+ )
214
184
215
185
# Create the slider
216
- sliders = [dict (
217
- active = 0 ,
218
- currentvalue = {"prefix" : "Index: " },
219
- pad = {"t" : 50 },
220
- steps = steps
221
- )]
186
+ sliders = [dict (active = 0 , currentvalue = {"prefix" : "Index: " }, pad = {"t" : 50 }, steps = steps )]
222
187
223
188
# Create the layout
224
189
fig .update_layout (
@@ -253,25 +218,32 @@ def plot_parameter_history(parameters: List[Parameter], parameter_history: List[
253
218
plt .xlabel ("Iterations" )
254
219
plt .ylabel (param .path )
255
220
split_param = param .path .split ("," )
256
- plt .plot (
257
- [
258
- parameter_history [i ][split_param [0 ]][split_param [1 ]]
259
- for i in range (len (parameter_history ))
260
- ]
261
- )
221
+ if "," in param .path :
222
+ split_param = param .path .split ("," )
223
+ plt .plot ([parameter_history [i ][split_param [0 ]][split_param [1 ]] for i in range (len (parameter_history ))])
224
+ else :
225
+ plt .plot ([parameter_history [i ][param .path ] for i in range (len (parameter_history ))])
262
226
plt .show ()
263
227
264
228
265
- def print_statements (statements : StatementDictionary , validation : Optional [StatementValidationDictionary ] = None , only_formalized : bool = False ):
229
+ def print_statements (
230
+ statements : StatementDictionary ,
231
+ validation : Optional [StatementValidationDictionary ] = None ,
232
+ only_formalized : bool = False ,
233
+ ):
266
234
"""
267
235
Print a list of statements in nice readable format.
268
236
"""
269
237
270
238
validation = StatementValidationDictionary (
271
- cost_functions = (validation .cost_functions if validation is not None else None ) or [StatementValidation ()]* len (statements .cost_functions or []),
272
- parameter_constraints = (validation .parameter_constraints if validation is not None else None ) or [StatementValidation ()]* len (statements .parameter_constraints or []),
273
- structure_constraints = (validation .structure_constraints if validation is not None else None ) or [StatementValidation ()]* len (statements .structure_constraints or []),
274
- unformalizable_statements = (validation .unformalizable_statements if validation is not None else None ) or [StatementValidation ()]* len (statements .unformalizable_statements or [])
239
+ cost_functions = (validation .cost_functions if validation is not None else None )
240
+ or [StatementValidation ()] * len (statements .cost_functions or []),
241
+ parameter_constraints = (validation .parameter_constraints if validation is not None else None )
242
+ or [StatementValidation ()] * len (statements .parameter_constraints or []),
243
+ structure_constraints = (validation .structure_constraints if validation is not None else None )
244
+ or [StatementValidation ()] * len (statements .structure_constraints or []),
245
+ unformalizable_statements = (validation .unformalizable_statements if validation is not None else None )
246
+ or [StatementValidation ()] * len (statements .unformalizable_statements or []),
275
247
)
276
248
277
249
if len (validation .cost_functions or []) != len (statements .cost_functions or []):
@@ -299,8 +271,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
299
271
if computation is not None :
300
272
args_str = ", " .join (
301
273
[
302
- f"{ argname } ="
303
- + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
274
+ f"{ argname } =" + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
304
275
for argname , argvalue in computation .arguments .items ()
305
276
]
306
277
)
@@ -326,8 +297,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
326
297
if computation is not None :
327
298
args_str = ", " .join (
328
299
[
329
- f"{ argname } ="
330
- + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
300
+ f"{ argname } =" + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
331
301
for argname , argvalue in computation .arguments .items ()
332
302
]
333
303
)
@@ -382,9 +352,7 @@ def _str_units_to_float(str_units: str) -> float:
382
352
return float (numeric_value * unit_conversions [unit ])
383
353
384
354
385
- def get_wavelengths_to_plot (
386
- statements : StatementDictionary , num_samples : int = 100
387
- ) -> Tuple [List [float ], List [float ]]:
355
+ def get_wavelengths_to_plot (statements : StatementDictionary , num_samples : int = 100 ) -> Tuple [List [float ], List [float ]]:
388
356
"""
389
357
Get the wavelengths to plot based on the statements.
390
358
@@ -401,10 +369,16 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
401
369
continue
402
370
if "wavelengths" in comp .arguments :
403
371
vlines = vlines | {
404
- _str_units_to_float (wl ) for wl in (comp .arguments ["wavelengths" ] if isinstance (comp .arguments ["wavelengths" ], list ) else []) if isinstance (wl , str )
372
+ _str_units_to_float (wl )
373
+ for wl in (comp .arguments ["wavelengths" ] if isinstance (comp .arguments ["wavelengths" ], list ) else [])
374
+ if isinstance (wl , str )
405
375
}
406
376
if "wavelength_range" in comp .arguments :
407
- if isinstance (comp .arguments ["wavelength_range" ], list ) and len (comp .arguments ["wavelength_range" ]) == 2 and all (isinstance (wl , str ) for wl in comp .arguments ["wavelength_range" ]):
377
+ if (
378
+ isinstance (comp .arguments ["wavelength_range" ], list )
379
+ and len (comp .arguments ["wavelength_range" ]) == 2
380
+ and all (isinstance (wl , str ) for wl in comp .arguments ["wavelength_range" ])
381
+ ):
408
382
min_wl = min (min_wl , _str_units_to_float (comp .arguments ["wavelength_range" ][0 ]))
409
383
max_wl = max (max_wl , _str_units_to_float (comp .arguments ["wavelength_range" ][1 ]))
410
384
return min_wl , max_wl , vlines
0 commit comments