Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion test/deeponet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,21 @@ using Test, Random, Flux
# Accept only Int as architecture parameters
@test_throws MethodError DeepONet((32.5,64,72), (24,48,72), σ, tanh)
@test_throws MethodError DeepONet((32,64,72), (24.1,48,72))
end
end

#Just the first 16 datapoints from the Burgers' equation dataset
a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755, 0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651, 0.81943734, 0.81737952, 0.8152405, 0.81302771]
sensors = collect(range(0, 1, length=16))'

model = DeepONet((16, 22, 30), (1, 16, 24, 30), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)

model(a,sensors)

#forward pass
@test size(model(a, sensors)) == (1, 16)

mgrad = Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)

#gradients
@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[1])
@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[2])
25 changes: 24 additions & 1 deletion test/fourierlayer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,27 @@ using Test, Random, Flux
# Test max amount of modes
@test_throws AssertionError FourierLayer(100, 100, 100, 60, σ)
@test_throws AssertionError FourierLayer(100, 100, 100, 60)
end
end

#Just the first 16 data points from Burgers' equation dataset
xtrain = Float32[0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755, 0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651, 0.81943734, 0.81737952, 0.8152405, 0.81302771]
grid = Float32.(collect(range(0, 1, length=16))')

x = cat(reshape(xtrain,(1,16,1)),
reshape(repeat(grid,1),(1,16,1));
dims=3)

x = permutedims(x,(3,2,1))
layer = FourierLayer(64, 64, 16, 8, gelu, bias_fourier=false)
model = Chain(Dense(2,64;bias=false), layer, layer, layer, layer,
Dense(64,2;bias=false))

model(x)

#forward pass
@test size(model(x)) == (2, 16, 1)

Flux.Zygote.gradient((x)->sum(model(x)), x)

#gradient test
@test !iszero(Flux.Zygote.gradient((x)->sum(model(x)), x)[1])