Commit d8e8c14

type stable prediction of FluxModelApplicator (#24)
* type stable prediction of FluxModelApplicator
* implement and test a type-stable rebuild; the FluxApplicator test shows no significant speedup
1 parent 7e1527e commit d8e8c14
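Background for the change (an illustrative sketch, not code from this commit): `Flux.destructure` returns a rebuild closure whose result the compiler often cannot infer concretely, in particular on GPU, which is why call sites previously carried type annotations such as `::MT` (see `src/elbo.jl` and `src/gf.jl` below). The commit moves that assertion into the applicator itself. A minimal standalone version of the pattern, assuming only Flux; the model, sizes, and `predict` helper are illustrative:

```julia
using Flux

m = Chain(Dense(3 => 8, tanh), Dense(8 => 2))
ϕ, rebuild = Flux.destructure(m)   # flat parameter vector + rebuild closure
x = rand(Float32, 3, 5)

# The rebuilt model's output may be inferred as Any through the closure.
# Asserting that the result has the same type as the (matrix) input restores
# type stability at the call site -- the same idea as FluxApplicator below.
predict(rebuild, ϕ, x::T) where {T} = rebuild(ϕ)(x)::T

y = predict(rebuild, ϕ, x)         # Matrix{Float32} of size (2, 5)
```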

6 files changed: +65 −20 lines

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 25 additions & 2 deletions
```diff
@@ -10,15 +10,37 @@ struct FluxApplicator{RT} <: AbstractModelApplicator
     rebuild::RT
 end
 
+struct PartricFluxApplicator{RT, MT, YT} <: AbstractModelApplicator
+    rebuild::RT
+end
+
+const FluxApplicatorU{RT} = Union{FluxApplicator{RT},PartricFluxApplicator{RT}} where RT
+
+
 function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType)
     # TODO: care for rng and float_type
     ϕ, rebuild = Flux.destructure(m)
     FluxApplicator(rebuild), ϕ
 end
 
-function HVI.apply_model(app::FluxApplicator, x, ϕ)
+function HVI.apply_model(app::FluxApplicator, x::T, ϕ) where T
+    # assume no size information in x -> can hint the type of the result
+    # to be the same as the type of the input
     m = app.rebuild(ϕ)
-    res = m(x)
+    res = m(x)::T
+    res
+end
+
+
+function HVI.construct_partric(app::FluxApplicator{RT}, x, ϕ) where RT
+    m = app.rebuild(ϕ)
+    y = m(x)
+    PartricFluxApplicator{RT, typeof(m), typeof(y)}(app.rebuild)
+end
+
+function HVI.apply_model(app::PartricFluxApplicator{RT, MT, YT}, x, ϕ) where {RT, MT, YT}
+    m = app.rebuild(ϕ)::MT
+    res = m(x)::YT
     res
 end
 
@@ -66,4 +88,5 @@ end
 
 
 
+
 end # module
```
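A usage sketch of the new applicators, modelled on `test/test_Flux.jl` below; the chain, sizes, and random data are illustrative, and it assumes the package and its Flux extension are loaded:

```julia
using Random, Flux
using HybridVariationalInference

rng = Random.default_rng()
m = Chain(Dense(4 => 8, tanh), Dense(8 => 3))

# destructure the chain into an applicator `g` and a flat parameter vector ϕ
g, ϕ = construct_ChainsApplicator(rng, m, Float32)
x = rand(Float32, 4, 10)

y = g(x, ϕ)                 # FluxApplicator: result asserted to typeof(x)

# construct_partric does one forward pass to record typeof(rebuilt model)
# and typeof(output) as type parameters, so later calls can assert both
gp = construct_partric(g, x, ϕ)
y2 = gp(x, ϕ)
y2 == y                     # same prediction; both calls are type-stable
```

Per the commit message, the parametric variant did not show a significant speedup over the plain type-asserted `FluxApplicator` in the benchmark sketched at the end of this page.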

src/HybridVariationalInference.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -28,6 +28,7 @@ include("bijectors_utils.jl")
 export AbstractComponentArrayInterpreter, ComponentArrayInterpreter,
     StaticComponentArrayInterpreter
 export flatten1, get_concrete, get_positions, stack_ca_int, compose_interpreters
+export construct_partric
 include("ComponentArrayInterpreter.jl")
 
 export AbstractModelApplicator, construct_ChainsApplicator
```

src/ModelApplicator.jl

Lines changed: 10 additions & 0 deletions
```diff
@@ -34,6 +34,16 @@ function apply_model(app::NullModelApplicator, x, ϕ)
     return x
 end
 
+"""
+Construct a parametric type-stable model applicator, given
+covariates, `x`, and parameters, `ϕ`.
+
+The default returns the current model applicator.
+"""
+function construct_partric(app::AbstractModelApplicator, x, ϕ)
+    app
+end
+
 
 """
     construct_ChainsApplicator([rng::AbstractRNG,] chain, float_type)
```
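For applicators that do not define a specialized method, the fallback above simply hands back the applicator itself. A tiny sketch; `NullModelApplicator()` is assumed here to be exported and constructible without arguments, and `x`, `ϕ` are dummy values:

```julia
using HybridVariationalInference

app = NullModelApplicator()           # assumed zero-argument constructor
x = rand(Float32, 4, 10)
ϕ = Float32[]

# generic fallback: no parametric wrapper is built, the same applicator comes back
construct_partric(app, x, ϕ) === app  # true
```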

src/elbo.jl

Lines changed: 16 additions & 15 deletions
```diff
@@ -296,7 +296,7 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT;
     xMP0 = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices)
     #Main.@infiltrate_main
 
-    μ_ζMs0 = g(xMP0, ϕg)::MT # for gpu restructure returns Any, so apply type
+    μ_ζMs0 = g(xMP0, ϕg)
     ζP_resids, ζMs_parfirst_resids, σ = sample_ζresid_norm(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_ends, int_unc)
     if pbm_covar_indices isa SA.SVector{0}
         # do not need to predict again but just add the residuals to μ_ζP and μ_ζMs
@@ -308,7 +308,7 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT;
         ζP = μ_ζP .+ rP
         # second pass: append ζP rather than μ_ζP to covars to xM
         xMP = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
-        μ_ζMst = g(xMP, ϕg)::MT # for gpu restructure returns Any, so apply type
+        μ_ζMst = g(xMP, ϕg)
         ζMs = (μ_ζMst .+ rMs)' # already transform to par-last form
         ζP, ζMs
     end
@@ -356,26 +356,27 @@ function get_pbm_covar_indices(ζP, pbm_covars::NTuple{0},
     SA.SA[]
 end
 
-# function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{N,Symbol}, g, ϕg, μ_ζMs0) where N
-#     xMP2 = _append_PBM_covars(xM, ζP, pbm_covars) # need different variable name?
+# remove?
+# # function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{N,Symbol}, g, ϕg, μ_ζMs0) where N
+# #     xMP2 = _append_PBM_covars(xM, ζP, pbm_covars) # need different variable name?
+# #     μ_ζMs = g(xMP2, ϕg)
+# # end
+# # function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{0}, g, ϕg, μ_ζMs0)
+# #     # if pbm_covars is the empty tuple, just return the original prediction on xM only
+# #     # rather than calling the ML model
+# #     μ_ζMs0
+# # end
+
+# function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0)
+#     xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
 #     μ_ζMs = g(xMP2, ϕg)
 # end
-# function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{0}, g, ϕg, μ_ζMs0)
+# function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0)
 #     # if pbm_covars is the empty tuple, just return the original prediction on xM only
 #     # rather than calling the ML model
 #     μ_ζMs0
 # end
 
-function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0)
-    xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
-    μ_ζMs = g(xMP2, ϕg)
-end
-function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0)
-    # if pbm_covars is the empty tuple, just return the original prediction on xM only
-    # rather than calling the ML model
-    μ_ζMs0
-end
-
 """
 Extract relevant parameters from ζ and return n_MC generated multivariate normal draws
 together with the vector of standard deviations, `σ`: `(ζP_resids, ζMs_parfirst_resids, σ)`
```

src/gf.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -120,10 +120,10 @@ end
 composition transM ∘ g: transformation after machine learning parameter prediction
 Provide a `transMs = StackedArray(transM, n_batch)`
 """
-function gtrans(g, transMs, xMP::T, ϕg; cdev) where T
+function gtrans(g, transMs, xMP, ϕg; cdev)
     # TODO remove after removing gf
     # predict the log of the parameters
-    ζMst = g(xMP, ϕg)::T # problem of Flux model applicator restructure
+    ζMst = g(xMP, ϕg)
     ζMs = ζMst'
     ζMs_cpu = cdev(ζMs)
     θMs = transMs(ζMs_cpu)
```

test/test_Flux.jl

Lines changed: 11 additions & 1 deletion
```diff
@@ -51,9 +51,19 @@ using Flux
     n_site = 3
     x = rand(Float32, n_covar, n_site) |> gpu
     ϕ = ϕg |> gpu
-    y = g(x, ϕ)
+    y = @inferred g(x, ϕ)
+    # @usingany Cthulhu
+    # @descend_code_warntype g(x, ϕ)
    #@test ϕ isa GPUArraysCore.AbstractGPUArray
    @test size(y) == (n_out, n_site)
+    gp = construct_partric(g, x, ϕ)
+    y2 = @inferred gp(x, ϕ)
+    @test y2 == y
+    () -> begin
+        # @usingany BenchmarkTools
+        #@benchmark g(x,ϕ)
+        #@benchmark gp(x,ϕ) # no difference type-inferred
+    end
 end;
 
 @testset "cpu_ca" begin
```
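The commented-out lines in the test hint at how the two applicators were compared; a sketch of that comparison with BenchmarkTools, reusing `g`, `gp`, `x`, and `ϕ` from the testset above (per the commit message, no significant speedup was observed):

```julia
using BenchmarkTools

# interpolate globals so the benchmark measures only the applicator calls
@benchmark $g($x, $ϕ)    # plain FluxApplicator with output type assertion
@benchmark $gp($x, $ϕ)   # PartricFluxApplicator with rebuild/result assertions
```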
