Skip to content

Commit 22e3a1d

Browse files
committed
add test for ComponentArrayInterpreter
and fix edge-case ambuiguity for empty ComponentVector put array adding empty dimesnions
1 parent 37226f4 commit 22e3a1d

File tree

6 files changed

+118
-14
lines changed

6 files changed

+118
-14
lines changed

Project.toml

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,33 @@ uuid = "a108c475-a4e2-4021-9a84-cfa7df242f64"
33
authors = ["Thomas Wutzler <twutz@bgc-jena.mpg.de> and contributors"]
44
version = "1.0.0-DEV"
55

6-
[weakdeps]
7-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
8-
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
9-
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
10-
116
[deps]
127
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
138
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
149
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1510
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1611
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1712

13+
[weakdeps]
14+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
15+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
16+
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
17+
18+
[extensions]
19+
HybridVariationalInferenceFluxExt = "Flux"
20+
HybridVariationalInferenceLuxExt = "Lux"
21+
HybridVariationalInferenceSimpleChainsExt = "SimpleChains"
22+
1823
[compat]
1924
Combinatorics = "1.0.2"
2025
ComponentArrays = "0.15.19"
2126
Flux = "v0.15.2"
2227
Lux = "1.4.2"
2328
Random = "1.10.0"
2429
SimpleChains = "0.4"
30+
StatsBase = "0.34.4"
31+
StatsFuns = "1.3.2"
2532
julia = "1.10"
2633

27-
[extensions]
28-
HybridVariationalInferenceSimpleChainsExt = "SimpleChains"
29-
HybridVariationalInferenceFluxExt = "Flux"
30-
HybridVariationalInferenceLuxExt = "Lux"
31-
3234
[workspace]
33-
projects = ["test", "docs"]
35+
projects = ["test", "docs"]

src/ComponentArrayInterpreter.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,13 @@ function ComponentArrayInterpreter(
137137
ComponentArrayInterpreter(axes_ext)
138138
end
139139

140+
# ambuiguity with two empty Tuples (edge case that does not make sense)
141+
# Empty ComponentVector with no other array dimenstions -> empty componentVector
142+
function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
143+
ComponentArrayInterpreter(CA.ComponentVector())
144+
end
145+
146+
140147

141148

142149
# not exported, but required for testing

src/HybridVariationalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Random
55
using StatsBase # fit ZScoreTransform
66
using Combinatorics # gen_hybridcase_synthetic/combinations
77

8-
export ComponentArrayInterpreter, flatten1
8+
export ComponentArrayInterpreter, flatten1, get_concrete
99
include("ComponentArrayInterpreter.jl")
1010

1111
export AbstractModelApplicator, construct_SimpleChainsApplicator, construct_FluxApplicator,

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml
33

44
@time begin
55
if GROUP == "All" || GROUP == "Basic"
6+
#@safetestset "test" include("test/test_ComponentArrayInterpreter.jl")
7+
@time @safetestset "test_ComponentArrayInterpreter" include("test_ComponentArrayInterpreter.jl")
68
#@safetestset "test" include("test/test_gencovar.jl")
79
@time @safetestset "test_gencovar" include("test_gencovar.jl")
810
#@safetestset "test" include("test/test_SimpleChains.jl")
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
using Test
2+
using HybridVariationalInference
3+
using HybridVariationalInference: HybridVariationalInference as CM
4+
using ComponentArrays: ComponentArrays as CA
5+
6+
@testset "ComponentArrayInterpreter vector" begin
7+
component_counts = comp_cnts = (; P=2, M=3, Unc=5)
8+
m = ComponentArrayInterpreter(; comp_cnts...)
9+
testm = (m) -> begin
10+
@test CM._get_ComponentArrayInterpreter_axes(m) == (CA.Axis(P=1:2, M=3:5, Unc=6:10),)
11+
@test length(m) == 10
12+
v = 1:length(m)
13+
cv = m(v)
14+
@test cv.Unc == 6:10
15+
end
16+
testm(m)
17+
testm(get_concrete(m))
18+
Base.isconcretetype(typeof(m))
19+
end;
20+
21+
# () -> begin
22+
# # test generate code for length
23+
# @code_llvm length(m)
24+
# mc = get_concrete(m)
25+
# @code_llvm length(mc)
26+
# v = 1:length(m)
27+
# @code_llvm as_ca(v,m)
28+
# @code_llvm as_ca(v,mc)
29+
# end
30+
31+
@testset "ComponentArrayInterpreter matrix in vector" begin
32+
component_shapes = (; P=2, M=(2, 3), Unc=5)
33+
m = ComponentArrayInterpreter(; component_shapes...)
34+
testm = (m) -> begin
35+
@test length(m) == 13
36+
a = 1:length(m)
37+
cv = m(a)
38+
@test cv.M == 2 .+ [1 3 5; 2 4 6]
39+
end
40+
testm(m)
41+
testm(get_concrete(m))
42+
end;
43+
44+
@testset "ComponentArrayInterpreter matrix and array" begin
45+
mv = ComponentArrayInterpreter(; c1=2, c2=3)
46+
cv = mv(1:length(mv))
47+
n_col = 4
48+
mm = ComponentArrayInterpreter(cv, (n_col,)) # 1-tuple
49+
testm = (m) -> begin
50+
@test length(mm) == length(cv) * n_col
51+
cm = mm(1:length(mm))
52+
#cm[:c1,:]
53+
@test cm[:c1, 2] == 6:7
54+
end
55+
testm(mm)
56+
mmc = get_concrete(mm)
57+
testm(mmc)
58+
#
59+
n_z = 3
60+
mm = ComponentArrayInterpreter(cv, (n_col, n_z))
61+
testm = (m) -> begin
62+
@test length(mm) == length(cv) * n_col * n_z
63+
cm = mm(1:length(mm))
64+
@test cm[:c1, 2, 2] == 26:27
65+
end
66+
testm(mm)
67+
testm(get_concrete(mm))
68+
#
69+
n_row = 3
70+
mm = ComponentArrayInterpreter((n_row,), cv)
71+
testm = (m) -> begin
72+
@test length(mm) == n_row * length(cv)
73+
cm = mm(1:length(mm))
74+
@test cm[2, :c1] == [2, 5]
75+
end
76+
testm(mm)
77+
testm(get_concrete(mm))
78+
end;
79+
80+
@testset "empty ComponentVector" begin
81+
x = CA.ComponentVector{Float32}()
82+
int1 = ComponentArrayInterpreter(x)
83+
@test int1(CA.getdata(x)) == x
84+
int2 = ComponentArrayInterpreter(x, ())
85+
@test int2 == int1
86+
int3 = ComponentArrayInterpreter((), x)
87+
@test int3 == int1
88+
end;
89+
90+
91+

test/test_doubleMM.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ end
5050
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
5151
Optimization.AutoZygote())
5252
optprob = Optimization.OptimizationProblem(optf, ϕg0);
53-
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);
53+
#res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);
54+
res = Optimization.solve(optprob, Adam(0.02), maxiters = 600);
5455

5556
ϕg_opt1 = res.u;
5657
pred = loss_g(ϕg_opt1, xM, g)
@@ -79,7 +80,8 @@ end
7980
optprob = OptimizationProblem(optf, p0, train_loader)
8081

8182
res = Optimization.solve(
82-
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
83+
# optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
84+
optprob, Adam(0.02), maxiters = 1000);
8385

8486
l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...)
8587
@test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11)

0 commit comments

Comments
 (0)