Skip to content

Commit 0754d73

Browse files
authored
refactor models (#636)
Signed-off-by: Andrey Parfenov <a1994ndrey@gmail.com>
1 parent aaa651b commit 0754d73

File tree

8 files changed

+21
-14
lines changed

8 files changed

+21
-14
lines changed

cpp_package/src/inc/data_filter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ class DataFilter
193193
* @return oxygen level
194194
*/
195195
static double get_oxygen_level (double *ppg_ir, double *ppg_red, int data_len,
196-
int sampling_rate, double coef1 = 0.0, double coef2 = -37.663, double coef3 = 114.91);
196+
int sampling_rate, double coef1 = 1.5958422, double coef2 = -34.6596622,
197+
double coef3 = 112.6898759);
197198
/**
198199
* calculate heart rate
199200
* @param ppg_ir input 1d array

csharp_package/brainflow/brainflow/data_filter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ public static double get_railed_percentage (double[] data, int gain)
306306
/// <param name="coef2">approximation coef</param>
307307
/// /// <param name="coef3">intercept for approximation</param>
308308
/// <returns>stddev</returns>
309-
public static double get_oxygen_level (double[] ppg_ir, double[] ppg_red, int sampling_rate, double coef1 = 0.0, double coef2 = -37.663, double coef3 = 114.91)
309+
public static double get_oxygen_level (double[] ppg_ir, double[] ppg_red, int sampling_rate, double coef1 = 1.5958422, double coef2 = -34.6596622, double coef3 = 112.6898759)
310310
{
311311
if (ppg_ir.Length != ppg_red.Length)
312312
{

java_package/brainflow/src/main/java/brainflow/DataFilter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ public static double get_oxygen_level (double[] ppg_ir, double[] ppg_red, int sa
251251
*/
252252
public static double get_oxygen_level (double[] ppg_ir, double[] ppg_red, int sampling_rate) throws BrainFlowError
253253
{
254-
return get_oxygen_level (ppg_ir, ppg_red, sampling_rate, 0.0, -37.663, 114.91);
254+
return get_oxygen_level (ppg_ir, ppg_red, sampling_rate, 1.5958422, -34.6596622, 112.6898759);
255255
}
256256

257257
/**

julia_package/brainflow/src/data_filter.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ end
253253
return output[1]
254254
end
255255

256-
@brainflow_rethrow function get_oxygen_level(ppg_ir, ppg_red, sampling_rate::Integer, coef1=0.0, coef2=-37.663, coef3=114.91)
256+
@brainflow_rethrow function get_oxygen_level(ppg_ir, ppg_red, sampling_rate::Integer, coef1=1.5958422, coef2=-34.6596622, coef3=112.6898759)
257257
if length(ppg_ir) != length(ppg_red)
258258
throw(BrainFlowError(string("invalid size", INVALID_ARGUMENTS_ERROR), Integer(INVALID_ARGUMENTS_ERROR)))
259259
end

python_package/brainflow/data_filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def get_railed_percentage(cls, data: NDArray[Float64], gain: int):
771771

772772
@classmethod
773773
def get_oxygen_level(cls, ppg_ir: NDArray[Float64], ppg_red: NDArray[Float64], sampling_rate: int,
774-
coef1=0.0, coef2=-37.663, coef3=114.91):
774+
coef1=1.5958422, coef2=-34.6596622, coef3=112.6898759):
775775
"""get oxygen level from ppg
776776
777777
:param ppg_ir: input array
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
#include "mindfulness_model.h"
33
// clang-format off
4-
const double mindfulness_coefficients[5] = {2.6338144674136394,4.006742906593334,-34.51389221061297,1.1950604401540308,35.78022137767881};
5-
double mindfulness_intercept = 0.364078;
4+
const double mindfulness_coefficients[5] = {-1.4765769283767163,2.6620930328900974,-31.43942997194057,9.464066586727622,45.512106420941684};
5+
double mindfulness_intercept = 0.000000;
66
// clang-format on
0 Bytes
Binary file not shown.

src/ml/train/train_classifiers.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def prepare_data(first_class, second_class, blacklisted_channels=None):
4242
dataset_y = list()
4343
for data_type in (first_class, second_class):
4444
for file in glob.glob(os.path.join('data', data_type, '*', '*.csv')):
45+
data_x_temp = list()
46+
data_y_temp = list()
4547
logging.info(file)
4648
board_id = os.path.basename(os.path.dirname(file))
4749
try:
@@ -58,13 +60,17 @@ def prepare_data(first_class, second_class, blacklisted_channels=None):
5860
feature_vector = bands[0]
5961
feature_vector = feature_vector.astype(float)
6062
dataset_x.append(feature_vector)
63+
data_x_temp.append(feature_vector)
6164
if data_type == first_class:
6265
dataset_y.append(0)
66+
data_y_temp.append(0)
6367
else:
6468
dataset_y.append(1)
69+
data_y_temp.append(0)
6570
cur_pos = cur_pos + int(window_size * overlaps[num] * sampling_rate)
6671
except Exception as e:
6772
logging.error(str(e), exc_info=True)
73+
print_dataset_info((data_x_temp, data_y_temp))
6874

6975
logging.info('1st Class: %d 2nd Class: %d' % (len([x for x in dataset_y if x == 0]), len([x for x in dataset_y if x == 1])))
7076

@@ -115,7 +121,7 @@ def print_dataset_info(data):
115121

116122
def train_regression_mindfulness(data):
117123
model = LogisticRegression(solver='liblinear', max_iter=4000,
118-
penalty='l2', random_state=2, fit_intercept=True, intercept_scaling=0.2)
124+
penalty='l2', random_state=2, fit_intercept=False, intercept_scaling=3)
119125
logging.info('#### Logistic Regression ####')
120126
scores = cross_val_score(model, data[0], data[1], cv=5, scoring='f1_macro', n_jobs=8)
121127
logging.info('f1 macro %s' % str(scores))
@@ -207,14 +213,14 @@ def main():
207213
dataset_y = pickle.load(f)
208214
data = dataset_x, dataset_y
209215
else:
210-
data = prepare_data('relaxed', 'focused', blacklisted_channels={'T3', 'T4'})
216+
data = prepare_data('relaxed', 'focused')
211217
print_dataset_info(data)
212218
train_regression_mindfulness(data)
213-
train_svm_mindfulness(data)
214-
train_knn_mindfulness(data)
215-
train_random_forest_mindfulness(data)
216-
train_mlp_mindfulness(data)
217-
train_stacking_classifier(data)
219+
#train_svm_mindfulness(data)
220+
#train_knn_mindfulness(data)
221+
#train_random_forest_mindfulness(data)
222+
#train_mlp_mindfulness(data)
223+
#train_stacking_classifier(data)
218224

219225

220226
if __name__ == '__main__':

0 commit comments

Comments
 (0)