Skip to content

Commit d28a992

Browse files
committed
to_dict and to_pandas added to FOOOFResults
1 parent 856ab5c commit d28a992

File tree

3 files changed

+96
-38
lines changed

3 files changed

+96
-38
lines changed

fooof/core/funcs.py

+48-27
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,30 @@
77
- They are left available for easy swapping back in, if desired.
88
"""
99

10+
from inspect import isfunction
11+
1012
import numpy as np
1113

1214
from fooof.core.errors import InconsistentDataError
1315

1416
###################################################################################################
1517
###################################################################################################
1618

17-
def gaussian_function(xs, *params):
19+
def gaussian_function(xs, cf, pw, bw, *params):
1820
"""Gaussian fitting function.
1921
2022
Parameters
2123
----------
2224
xs : 1d array
2325
Input x-axis values.
26+
cf : float
27+
The center of the gaussian.
28+
pw : float
29+
The height of the gaussian.
30+
bw : float
31+
The width of the gaussian.
2432
*params : float
25-
Parameters that define gaussian function.
33+
Additional centers, heights, and widths.
2634
2735
Returns
2836
-------
@@ -32,16 +40,17 @@ def gaussian_function(xs, *params):
3240

3341
ys = np.zeros_like(xs)
3442

43+
params = [cf, pw, bw, *params]
44+
3545
for ii in range(0, len(params), 3):
3646

3747
ctr, hgt, wid = params[ii:ii+3]
38-
3948
ys = ys + hgt * np.exp(-(xs-ctr)**2 / (2*wid**2))
4049

4150
return ys
4251

4352

44-
def expo_function(xs, *params):
53+
def expo_function(xs, offset, knee, exp):
4554
"""Exponential fitting function, for fitting aperiodic component with a 'knee'.
4655
4756
NOTE: this function requires linear frequency (not log).
@@ -50,26 +59,32 @@ def expo_function(xs, *params):
5059
----------
5160
xs : 1d array
5261
Input x-axis values.
53-
*params : float
54-
Parameters (offset, knee, exp) that define Lorentzian function:
55-
y = 10^offset * (1/(knee + x^exp))
62+
offset : float
63+
The y-intercept of the fit.
64+
knee : float
65+
The bend in the fit.
66+
exp : float
67+
The exponential slope of the fit.
5668
5769
Returns
5870
-------
5971
ys : 1d array
6072
Output values for exponential function.
73+
74+
Notes
75+
-----
76+
Parameters (offset, knee, exp) that define Lorentzian function:
77+
y = 10^offset * (1/(knee + x^exp))
6178
"""
6279

6380
ys = np.zeros_like(xs)
6481

65-
offset, knee, exp = params
66-
6782
ys = ys + offset - np.log10(knee + xs**exp)
6883

6984
return ys
7085

7186

72-
def expo_nk_function(xs, *params):
87+
def expo_nk_function(xs, offset, exp):
7388
"""Exponential fitting function, for fitting aperiodic component without a 'knee'.
7489
7590
NOTE: this function requires linear frequency (not log).
@@ -78,34 +93,40 @@ def expo_nk_function(xs, *params):
7893
----------
7994
xs : 1d array
8095
Input x-axis values.
81-
*params : float
82-
Parameters (offset, exp) that define Lorentzian function:
83-
y = 10^off * (1/(x^exp))
96+
offset : float
97+
The y-intercept of the fit.
98+
exp : float
99+
The exponential slope of the fit.
84100
85101
Returns
86102
-------
87103
ys : 1d array
88104
Output values for exponential function, without a knee.
105+
106+
Notes
107+
-----
108+
Parameters (offset, exp) that define Lorentzian function:
109+
y = 10^off * (1/(x^exp))
89110
"""
90111

91112
ys = np.zeros_like(xs)
92113

93-
offset, exp = params
94-
95114
ys = ys + offset - np.log10(xs**exp)
96115

97116
return ys
98117

99118

100-
def linear_function(xs, *params):
119+
def linear_function(xs, offset, slope):
101120
"""Linear fitting function.
102121
103122
Parameters
104123
----------
105124
xs : 1d array
106125
Input x-axis values.
107-
*params : float
108-
Parameters that define linear function.
126+
offset : float
127+
The y-intercept of the fit.
128+
slope : float
129+
The slope of the fit.
109130
110131
Returns
111132
-------
@@ -115,22 +136,24 @@ def linear_function(xs, *params):
115136

116137
ys = np.zeros_like(xs)
117138

118-
offset, slope = params
119-
120139
ys = ys + offset + (xs*slope)
121140

122141
return ys
123142

124143

125-
def quadratic_function(xs, *params):
144+
def quadratic_function(xs, offset, slope, curve):
126145
"""Quadratic fitting function.
127146
128147
Parameters
129148
----------
130149
xs : 1d array
131150
Input x-axis values.
132-
*params : float
133-
Parameters that define quadratic function.
151+
offset : float
152+
The y-intercept of the fit.
153+
slope : float
154+
The slope of the fit.
155+
curve : float
156+
The curve of the fit.
134157
135158
Returns
136159
-------
@@ -140,8 +163,6 @@ def quadratic_function(xs, *params):
140163

141164
ys = np.zeros_like(xs)
142165

143-
offset, slope, curve = params
144-
145166
ys = ys + offset + (xs*slope) + ((xs**2)*curve)
146167

147168
return ys
@@ -167,7 +188,7 @@ def get_pe_func(periodic_mode):
167188
168189
"""
169190

170-
if isinstance(periodic_mode, function):
191+
if isfunction(periodic_mode):
171192
pe_func = periodic_mode
172193
elif periodic_mode == 'gaussian':
173194
pe_func = gaussian_function
@@ -196,7 +217,7 @@ def get_ap_func(aperiodic_mode):
196217
If the specified aperiodic mode label is not understood.
197218
"""
198219

199-
if isinstance(aperiodic_mode, function):
220+
if isfunction(aperiodic_mode):
200221
ap_func = aperiodic_mode
201222
elif aperiodic_mode == 'fixed':
202223
ap_func = expo_nk_function

fooof/data/data.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,30 @@ class FOOOFResults(namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params'
8484
-----
8585
This object is a data object, based on a NamedTuple, with immutable data attributes.
8686
"""
87+
8788
__slots__ = ()
8889

90+
def to_dict(self):
91+
92+
# Combined peak, aperiodic, and goodness of fit params
93+
results_dict = OrderedDict()
94+
results_dict.update(self.peak_params.to_dict())
95+
results_dict.update(self.aperiodic_params.to_dict())
96+
results_dict.update(OrderedDict(r_squared=self.r_squared, error=self.error))
97+
98+
return results_dict
99+
100+
def to_pandas(self):
101+
102+
if pd is None:
103+
raise ValueError("Pandas is not installed.")
104+
105+
results_dict = self.to_dict()
106+
107+
results_df = pd.DataFrame(results_dict)
108+
109+
return results_df
110+
89111

90112
class SimParams(namedtuple('SimParams', ['aperiodic_params', 'periodic_params', 'nlv'])):
91113
"""Parameters that define a simulated power spectrum.
@@ -108,29 +130,29 @@ class SimParams(namedtuple('SimParams', ['aperiodic_params', 'periodic_params',
108130

109131
class FitParams(np.ndarray):
110132

111-
def __new__(cls, results, labels):
133+
def __new__(cls, params, labels):
112134

113-
return np.asarray(results).view(cls)
135+
return np.asarray(params).view(cls)
114136

115-
def __init__(self, results, labels):
137+
def __init__(self, params, labels):
116138

117-
self.results = results
139+
self.params = params
118140
self.labels = labels
119141

120142
def to_dict(self):
121143

122-
results = OrderedDict((k, v) for k, v in \
123-
zip(self.labels, self.results.transpose()))
144+
params = OrderedDict((k, v) for k, v in \
145+
zip(self.labels, self.params.transpose()))
124146

125-
return results
147+
return params
126148

127149
def to_pandas(self):
128150

129151
if pd is None:
130152
raise ValueError("Pandas is not installed.")
131153

132-
results = self.to_dict()
154+
params = self.to_dict()
133155

134-
results = pd.DataFrame(results)
156+
params = pd.DataFrame(params)
135157

136-
return results
158+
return params

fooof/objs/fit.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656

5757
import warnings
5858
from copy import deepcopy
59+
from inspect import signature
5960

6061
import numpy as np
6162
from numpy.linalg import LinAlgError
@@ -77,7 +78,7 @@
7778
from fooof.plts.style import style_spectrum_plot
7879
from fooof.utils.data import trim_spectrum
7980
from fooof.utils.params import compute_gauss_std
80-
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData
81+
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData, FitParams
8182
from fooof.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model
8283

8384
###################################################################################################
@@ -484,6 +485,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
484485
# Convert gaussian definitions to peak parameters
485486
self.peak_params_ = self._create_peak_params(self.gaussian_params_)
486487

488+
# Create a flexible datatype object for ap/pe parameters
489+
pe_labels = list(signature(self._pe_func).parameters.keys())[1:]
490+
ap_labels = list(signature(self._ap_func).parameters.keys())[1:]
491+
492+
self.peak_params_ = FitParams(self.peak_params_, pe_labels)
493+
self.gaussian_params_ = FitParams(self.gaussian_params_, pe_labels)
494+
self.aperiodic_params_ = FitParams(self.aperiodic_params_, ap_labels)
495+
487496
# Calculate R^2 and error of the model fit
488497
self._calc_r_squared()
489498
self._calc_error()
@@ -636,6 +645,12 @@ def get_results(self):
636645
return FOOOFResults(**{key.strip('_') : getattr(self, key) \
637646
for key in OBJ_DESC['results']})
638647

648+
def get_results_dict(self):
649+
"""Return model fit parameters and goodness of fit metrics as an ordered dictionary."""
650+
651+
results = self.get_results()
652+
653+
#pe_params =
639654

640655
@copy_doc_func_to_method(plot_fm)
641656
def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,

0 commit comments

Comments
 (0)