Skip to content

Commit 3a825ac

Browse files
committed
add prediction only workflow
1 parent 57d7ebd commit 3a825ac

File tree

2 files changed

+133
-118
lines changed

2 files changed

+133
-118
lines changed

DockerizedAutoML/main.jl

Lines changed: 124 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -7,147 +7,155 @@ using Statistics
77

88

99
"""
    parse_commandline()

Declare the command-line interface of the script and parse the process
arguments.

# Returns
The parsed options as a dictionary keyed by `Symbol` (`as_symbols=true`), e.g.
`:url`, `:prediction_type`, `:complexity`, `:output_file`, `:nfolds`,
`:nworkers`, `:no_save`, `:predict_only`, `:runid` and the required positional
`:csvfile`.
"""
function parse_commandline()
    s = ArgParseSettings()
    @add_arg_table! s begin
        "--url", "-u"
        help = "mlflow server url"
        arg_type = String
        default = "http://localhost:8080"
        "--prediction_type", "-t"
        help = "classification, regression, anomalydetection"
        arg_type = String
        default = "classification"
        "--complexity", "-c"
        help = "pipeline complexity"
        arg_type = String
        default = "low"
        # "NONE" is the sentinel for "print to stdout only".
        "--output_file", "-o"
        help = "output location"
        arg_type = String
        default = "NONE"
        "--nfolds", "-f"
        help = "number of crossvalidation folds"
        arg_type = Int64
        default = 3
        "--nworkers", "-w"
        help = "number of workers"
        arg_type = Int64
        default = 5
        # The flag disables saving, so the help text must say so
        # (it previously read "save model").
        "--no_save"
        help = "do not save the trained model"
        action = :store_true
        "--predict_only"
        help = "no training, predict only"
        action = :store_true
        "--runid"
        help = "runid of experiment for trained model"
        arg_type = String
        default = "NONE"
        "csvfile"
        help = "input csv file"
        required = true
    end
    return parse_args(s; as_symbols=true)
end
5252

5353
# Parse the command line once at load time; the rest of the script reads
# its options from these constants.
const _cliargs = parse_commandline()
const _workers = _cliargs[:nworkers]

# Worker processes are only started when training is requested. Each worker
# is launched with the same active project as this process.
if !_cliargs[:predict_only]
    if nprocs() == 1
        addprocs(_workers; exeflags=["--project=$(Base.active_project())"])
    end
    @everywhere using AutoAI
end
6060

6161
"""
    autoclassmode(args::Dict)

Fit an `AutoMLFlowClassification` workflow on the CSV file named by
`args[:csvfile]`.

Every column except the last is used as input `X`; the last column is the
target `Y`. The fraction of rows where the output of `fit_transform!` equals
`Y` is printed as `accuracy`.

# Arguments
- `args`: parsed options; reads `:url`, `:complexity`, `:nfolds`, `:nworkers`,
  `:prediction_type` and `:csvfile`.

# Returns
The fitted workflow object.
"""
function autoclassmode(args::Dict)
    url = args[:url]
    complexity = args[:complexity]
    nfolds = args[:nfolds]
    nworkers = args[:nworkers]
    prediction_type = args[:prediction_type]
    impl_args = (; complexity, nfolds, nworkers, prediction_type) |> pairs |> Dict
    # Take the input path from `args` (not the global `_cliargs`) so the
    # function honours the dictionary it is actually given.
    df = CSV.read(args[:csvfile], DataFrame)
    X = df[:, 1:end-1]
    Y = df[:, end] |> collect
    autoclass = AutoMLFlowClassification(Dict(:url => url, :impl_args => impl_args))
    Yc = fit_transform!(autoclass, X, Y)
    println("accuracy = ", mean(Y .== Yc))
    return autoclass
end
7777

7878
"""
    autoregmode(args::Dict)

Fit an `AutoMLFlowRegression` workflow on the CSV file named by
`args[:csvfile]`.

Every column except the last is used as input `X`; the last column is the
target `Y`. The mean squared difference between `Y` and the output of
`fit_transform!` is printed as `mse`.

# Arguments
- `args`: parsed options; reads `:url`, `:complexity`, `:nfolds`, `:nworkers`,
  `:prediction_type` and `:csvfile`.

# Returns
The fitted workflow object.
"""
function autoregmode(args::Dict)
    url = args[:url]
    complexity = args[:complexity]
    nfolds = args[:nfolds]
    nworkers = args[:nworkers]
    prediction_type = args[:prediction_type]
    impl_args = (; complexity, nfolds, nworkers, prediction_type) |> pairs |> Dict
    # Take the input path from `args` (not the global `_cliargs`) so the
    # function honours the dictionary it is actually given.
    df = CSV.read(args[:csvfile], DataFrame)
    X = df[:, 1:end-1]
    Y = df[:, end] |> collect
    autoreg = AutoMLFlowRegression(Dict(:url => url, :impl_args => impl_args))
    Yc = fit_transform!(autoreg, X, Y)
    println("mse = ", mean((Y - Yc) .^ 2))
    return autoreg
end
9494

9595
"""
    doprediction_only(args::Dict)

Run prediction without training, using the model identified by
`args[:runid]` on the server at `args[:url]`.

The whole CSV file `args[:csvfile]` is passed to `transform!` as input. The
prediction is printed to stdout and, when `args[:output_file]` is not the
sentinel `"NONE"`, also written to that file.

# Arguments
- `args`: parsed options; reads `:csvfile`, `:runid`, `:url`,
  `:prediction_type` and `:output_file`.

# Returns
The value returned by `transform!`.

# Throws
An `ErrorException` when `args[:prediction_type]` is neither
`"classification"` nor `"regression"`.
"""
function doprediction_only(args::Dict)
    fname = args[:csvfile]
    X = CSV.read(fname, DataFrame)
    run_id = args[:runid]
    url = args[:url]
    # NOTE(review): `run_id` defaults to "NONE" on the command line; whether the
    # workflow constructors reject that value is not visible here — confirm.
    predtype = args[:prediction_type]
    # Bind `mlf` exactly once; a stray `mlf =` line previously chained into the
    # `predtype` assignment and made `mlf` a String before it was rebound.
    mlf = if predtype == "classification"
        AutoMLFlowClassification(Dict(:run_id => run_id, :url => url))
    elseif predtype == "regression"
        AutoMLFlowRegression(Dict(:run_id => run_id, :url => url))
    else
        error("unknown predtype option")
    end
    Yn = transform!(mlf, X)
    ofile = args[:output_file]
    if ofile != "NONE"
        open(ofile, "w") do stfile
            println(stfile, "prediction: $Yn")
            println(stdout, "prediction: $Yn")
        end
    else
        println(stdout, "prediction: $Yn")
    end
    return Yn
end
113121

114122
"""
    printsummary(io::IO, automl::Workflow)

Write a three-line summary of a trained workflow to `io`: the descriptions of
the evaluated pipelines, the description of the best pipeline, and the first
entry of its performance mean and standard deviation, each rounded to two
digits.
"""
function printsummary(io::IO, automl::Workflow)
    fitted = automl.model[:automodel].model
    bestdesc = fitted[:bestpipeline].model[:description]
    println(io, "pipelines: $(fitted[:dfpipelines].Description)")
    println(io, "best_pipeline: $bestdesc")
    perf = fitted[:performance]
    avg = round(perf.mean[1]; digits=2)
    spread = round(perf.sd[1]; digits=2)
    println(io, "best_pipeline_performance: $(avg) ± $(spread)")
    return nothing
end
124132

125133
"""
    dotrainandpredict(args::Dict)

Train a workflow of the kind selected by `args[:prediction_type]` and print its
summary.

The summary goes to stdout and, when `args[:output_file]` is not the sentinel
`"NONE"`, is also written to that file.

# Arguments
- `args`: parsed options; reads `:prediction_type` and `:output_file`, and is
  forwarded unchanged to `autoclassmode` / `autoregmode`.

# Returns
`nothing`.

# Throws
An `ErrorException` when `args[:prediction_type]` is neither
`"classification"` nor `"regression"`.
"""
function dotrainandpredict(args::Dict)
    # train model
    predtype = args[:prediction_type]
    automl = if predtype == "classification"
        autoclassmode(args)
    elseif predtype == "regression"
        autoregmode(args)
    else
        # Without this branch `automl` would be `nothing` and the failure would
        # only surface later as a MethodError inside `printsummary`.
        error("unsupported prediction_type for training: $predtype")
    end
    ofile = args[:output_file]
    if ofile != "NONE"
        open(ofile, "w") do stfile
            printsummary(stfile, automl)
            printsummary(stdout, automl)
        end
    else
        printsummary(stdout, automl)
    end
    return nothing
end
143151

144152
"""
    main(args::Dict)

Entry point: dispatch to prediction-only mode when `args[:predict_only]` is
set, otherwise train and then report. Returns whatever the selected mode
returns.
"""
function main(args::Dict)
    # predict only using run_id of model in the artifact
    args[:predict_only] == true && return doprediction_only(args)
    # train and predict
    return dotrainandpredict(args)
end

main(_cliargs)

DockerizedAutoML/run.sh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
# Build the image once and refer to it by the same name:tag everywhere below.
# The later `docker run` lines previously used `localhost/automlai`, which
# Docker resolves as a different repository (and the `latest` tag) than the
# `automlai:v2.0` image built here.
IMAGE=automlai:v2.0

docker build -t "$IMAGE" --platform=linux/amd64 .
docker run -it --rm --platform=linux/amd64 "$IMAGE"

# julia --project -- ./main.jl -c high -t regression -f 3 -w 7 iris_reg.csv
# julia --project -- ./main.jl -c low -t classification -f 3 -w 3 iris.csv
# julia --project -- ./main.jl -c low -t anomalydetection iris.csv
# podman run -it --rm --platform=linux/amd64 localhost/automlai -u http://spendor2.sl.cloud9.ibm.com:30412 iris.csv
# podman run -it --rm -v `pwd`:/data/ localhost/automlai -u http://spendor2.sl.cloud9.ibm.com:30412 -t regression /data/iris_reg.csv
# julia --project -- ./main.jl -c low -t classification -f 3 -w 3 iris.csv --predict_only --runid cd4e463d6a414aa4aaad173e567d7d22 -o /tmp/hello.txt

# Prediction-only runs against previously trained models (local Julia).
julia --project -- ./main.jl -t regression --predict_only -u http://mlflow.isiath.duckdns.org:8082 --runid 064fb7a188d34a3da87f2271b8d8d9c2 -o /tmp/reg.txt ./iris_reg.csv
julia --project -- ./main.jl -u http://mlflow.isiath.duckdns.org:8082 -t classification --predict_only --runid e33bbd5c12a54756b1333df1f23a8366 -o /tmp/class.txt ./iris.csv

# Prediction-only runs inside the container; the current directory is mounted at /data/.
docker run -it --rm -v "$(pwd)":/data/ "$IMAGE" -u http://mlflow.isiath.duckdns.org:8082 -t classification --predict_only --runid e33bbd5c12a54756b1333df1f23a8366 /data/iris.csv

docker run -it --rm -v "$(pwd)":/data/ "$IMAGE" -u http://mlflow.isiath.duckdns.org:8082 -t regression --predict_only --runid 064fb7a188d34a3da87f2271b8d8d9c2 /data/iris_reg.csv

0 commit comments

Comments
 (0)