-
Notifications
You must be signed in to change notification settings - Fork 1
dt_bestguess defaulted to false #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
# using SoleModels | ||
# using MLJ | ||
# using DataFrames, Random | ||
# using DecisionTree | ||
# const DT = DecisionTree | ||
|
||
Comment on lines
+1
to
+6
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These seem useful actually? |
||
X, y = @load_iris | ||
X = DataFrame(X) | ||
|
||
|
@@ -27,7 +33,7 @@ model = Stump(; | |
# Bind the model and data into a machine | ||
mach = machine(model, X_train, y_train) | ||
# Fit the model | ||
fit!(mach, verbosity=0) | ||
MLJ.fit!(mach, verbosity=0) | ||
|
||
weights = mach.fitresult[2] | ||
classlabels = sort(mach.fitresult[3]) | ||
|
@@ -72,7 +78,7 @@ ada_accuracy = sum(preds .== y_test)/length(y_test) | |
Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree | ||
dt_model = Tree(max_depth=-1, min_samples_leaf=1, min_samples_split=2) | ||
dt_mach = machine(dt_model, X_train, y_train) | ||
fit!(dt_mach, verbosity=0) | ||
MLJ.fit!(dt_mach, verbosity=0) | ||
dt_solem = solemodel(fitted_params(dt_mach).tree) | ||
dt_preds = apply(dt_solem, X_test) | ||
dt_accuracy = sum(dt_preds .== y_test)/length(y_test) | ||
|
@@ -81,7 +87,7 @@ dt_accuracy = sum(dt_preds .== y_test)/length(y_test) | |
Forest = MLJ.@load RandomForestClassifier pkg=DecisionTree | ||
rm_model = Forest(; max_depth=3, min_samples_leaf=1, min_samples_split=2, n_trees=10, rng) | ||
rm_mach = machine(rm_model, X_train, y_train) | ||
fit!(rm_mach, verbosity=0) | ||
MLJ.fit!(rm_mach, verbosity=0) | ||
classlabels = (rm_mach).fitresult[2] | ||
classlabels = classlabels[sortperm((rm_mach).fitresult[3])] | ||
featurenames = report(rm_mach).features | ||
|
@@ -111,19 +117,19 @@ println("RandomForest accuracy: ", rm_accuracy) | |
# solemodel | ||
model = Stump(; n_iter, rng=Xoshiro(seed)) | ||
mach = machine(model, X_train, y_train) | ||
fit!(mach, verbosity=0) | ||
MLJ.fit!(mach, verbosity=0) | ||
weights = mach.fitresult[2] | ||
classlabels = sort(mach.fitresult[3]) | ||
featurenames = MLJ.report(mach).features | ||
solem = solemodel(MLJ.fitted_params(mach).stumps; weights, classlabels, featurenames) | ||
solem = solemodel(MLJ.fitted_params(mach).stumps; weights, classlabels, featurenames, dt_bestguess=true) | ||
preds = apply(solem, X_test) | ||
|
||
# decisiontree | ||
yl_train = CategoricalArrays.levelcode.(y_train) | ||
yl_train = MLJ.levelcode.(y_train) | ||
dt_model, dt_coeffs = DT.build_adaboost_stumps(yl_train, Matrix(X_train), n_iter; rng=Xoshiro(seed)) | ||
dt_preds = DT.apply_adaboost_stumps(dt_model, dt_coeffs, Matrix(X_test)) | ||
|
||
code_preds = CategoricalArrays.levelcode.(preds) | ||
code_preds = MLJ.levelcode.(preds) | ||
@test code_preds == dt_preds | ||
end | ||
end | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
# using SoleModels | ||
# using MLJ | ||
# using DataFrames, Random | ||
# using DecisionTree | ||
# const DT = DecisionTree | ||
|
||
Comment on lines
+1
to
+6
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same |
||
X, y = @load_iris | ||
X = DataFrame(X) | ||
|
||
|
@@ -24,7 +30,7 @@ model = Tree( | |
# Bind the model and data into a machine | ||
mach = machine(model, X_train, y_train) | ||
# Fit the model | ||
fit!(mach) | ||
MLJ.fit!(mach) | ||
|
||
|
||
solem = solemodel(fitted_params(mach).tree) | ||
|
@@ -80,16 +86,16 @@ printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_d | |
# solemodel | ||
model = Tree(; max_depth, rng=Xoshiro(seed)) | ||
mach = machine(model, X_train, y_train) | ||
fit!(mach, verbosity=0) | ||
MLJ.fit!(mach, verbosity=0) | ||
solem = solemodel(MLJ.fitted_params(mach).tree) | ||
preds = apply!(solem, X_test, y_test) | ||
|
||
# decisiontree | ||
y_coded_train = @. CategoricalArrays.levelcode(y_train) | ||
y_coded_train = @. MLJ.levelcode(y_train) | ||
dt_model = DT.build_tree(y_coded_train, Matrix(X_train), 0, max_depth; rng=Xoshiro(seed)) | ||
dt_preds = DT.apply_tree(dt_model, Matrix(X_test)) | ||
|
||
preds_coded = CategoricalArrays.levelcode.(CategoricalArray(preds)) | ||
preds_coded = MLJ.levelcode.(MLJ.CategoricalArray(preds)) | ||
@test preds_coded == dt_preds | ||
end | ||
end | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -170,9 +170,9 @@ branch_r = @test_nowarn Branch(formula_r, (branch_r, "yes")) | |
rule_r = @test_nowarn Rule(formula_r, branch_r) | ||
branch_r_mixed = @test_nowarn Branch(formula_r, (rule_r, "no")) | ||
|
||
dtmodel0 = @test_nowarn DecisionTree("1") | ||
dtmodel = @test_nowarn DecisionTree(branch_r) | ||
@test_nowarn DecisionTree(branch_r_mixed) | ||
dtmodel0 = @test_nowarn SoleModels.DecisionTree("1") | ||
dtmodel = @test_nowarn SoleModels.DecisionTree(branch_r) | ||
@test_nowarn SoleModels.DecisionTree(branch_r_mixed) | ||
Comment on lines
-173
to
+175
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, why is this needed? Is there a clash on "DecisionTree"? Maybe with MLJ? |
||
# msmodel = MixedModel(dtmodel) | ||
|
||
complex_mixed_model = @test_nowarn Branch(formula_r, (dtmodel, dlmodel_integer)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, good to have a kwarg for this. Can we name it differently? I guess we could remove the "dt_" part, because we may want to have the same argument in other
SoleModels.solemodel
methods in other package extensions unrelated to "decision trees". Maybe "alphanumeric_tiebreaker", or "argmax_tiebreaker", or "tiebreaker::Symbol" that can be either
:argmax
or :alphanumeric