
Commit 216a8ac

add 'size' for each layer type
1 parent 3c4933f commit 216a8ac

File tree: 11 files changed, +39 −1 lines changed


src/Architecture/Architecture.jl

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ using Requires
 using LinearAlgebra: dot
 using Statistics: mean
 
+import Base: size
 export AbstractNeuralNetwork, FeedforwardNetwork,
        AbstractLayerOp, DenseLayerOp, ConvolutionalLayerOp, FlattenLayerOp,
        AbstractPoolingLayerOp, MaxPoolingLayerOp, MeanPoolingLayerOp,
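
Note: the `import Base: size` line is what lets the per-layer `size` definitions in the files below extend Julia's standard `size` function rather than shadowing it with a module-local name. Judging from the values added in this commit, the returned tuple appears to encode the expected rank of a layer's input and output tensors (`nothing` meaning any rank); that reading is an inference, not stated in the commit. A minimal, self-contained sketch of the pattern, using a hypothetical layer type:

    import Base: size

    struct MyLayerOp end  # hypothetical layer type, for illustration only

    # rank-3 input tensor, rank-3 output tensor (same values as ConvolutionalLayerOp below)
    size(::MyLayerOp) = (3, 3)

    size(MyLayerOp()) == (3, 3)  # true; the call dispatches through Base.size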

src/Architecture/LayerOps/ConvolutionalLayerOp.jl

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,8 @@ function Base.show(io::IO, L::ConvolutionalLayerOp)
     return print(io, str)
 end
 
+size(::ConvolutionalLayerOp) = (3, 3)
+
 function load_Flux_convert_Conv_layer()
     return quote
         function Base.convert(::Type{ConvolutionalLayerOp}, layer::Flux.Conv)

src/Architecture/LayerOps/DenseLayerOp.jl

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ dim_in(L::DenseLayerOp) = size(L.weights, 2)
 
 dim_out(L::DenseLayerOp) = length(L.bias)
 
+size(::DenseLayerOp) = (1, 1)
+
 function load_Flux_convert_Dense_layer()
     return quote
         function Base.convert(::Type{DenseLayerOp}, layer::Flux.Dense)

src/Architecture/LayerOps/FlattenLayerOp.jl

Lines changed: 2 additions & 0 deletions
@@ -42,3 +42,5 @@ function (L::FlattenLayerOp)(T)
     end
     return vec(permutedims(T, (2, 1, 3:length(s)...)))
 end
+
+size(::FlattenLayerOp) = (nothing, 1)

src/Architecture/LayerOps/PoolingLayerOp.jl

Lines changed: 2 additions & 0 deletions
@@ -90,3 +90,5 @@ function (L::AbstractPoolingLayerOp)(T)
     end
     return O
 end
+
+size(::AbstractPoolingLayerOp) = (3, 3)

src/Architecture/NeuralNetworks/FeedforwardNetwork.jl

Lines changed: 7 additions & 1 deletion
@@ -32,14 +32,20 @@ end
 function _first_inconsistent_layer(L)
     prev = nothing
     for (i, l) in enumerate(L)
-        if !isnothing(prev) && dim_in(l) != dim_out(prev)
+        if !isnothing(prev) &&
+           ((!isnothing(dim_in(l)) && !isnothing(dim_out(prev)) && dim_in(l) != dim_out(prev)) ||
+            !_iscompatible(size(prev), size(l)))
             return i
         end
         prev = l
     end
     return 0
 end
 
+_iscompatible(t1::Tuple, t2::Tuple) = _iscompatible(t1[2], t2[1])
+_iscompatible(i::Int, j::Int) = i == j
+_iscompatible(i, ::Nothing) = true
+
 layers(N::FeedforwardNetwork) = N.layers
 
 function load_Flux_convert_network()
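
For reference, here is how the new compatibility check composes with the per-layer `size` values added in this commit. This is a standalone sketch (the `_iscompatible` definitions are copied from the diff above; no layer constructors are needed), and reading the tuples as input/output tensor ranks is an assumption rather than something the commit states:

    # compatibility helpers, copied from the diff above
    _iscompatible(t1::Tuple, t2::Tuple) = _iscompatible(t1[2], t2[1])
    _iscompatible(i::Int, j::Int) = i == j
    _iscompatible(i, ::Nothing) = true

    # conv (3, 3) feeding pooling (3, 3): output rank 3 matches input rank 3
    _iscompatible((3, 3), (3, 3))        # true
    # pooling (3, 3) feeding flatten (nothing, 1): `nothing` accepts any rank
    _iscompatible((3, 3), (nothing, 1))  # true
    # flatten (nothing, 1) feeding dense (1, 1): vector output meets vector input
    _iscompatible((nothing, 1), (1, 1))  # true
    # conv (3, 3) feeding dense (1, 1) directly: rank mismatch,
    # so _first_inconsistent_layer flags the dense layer
    _iscompatible((3, 3), (1, 1))        # false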

test/Architecture/ConvolutionalLayerOp.jl

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ println(io, L)
       !(L ≈ ConvolutionalLayerOp(Ws, [b1 .+ 1], ReLU())) &&
       !(L ≈ ConvolutionalLayerOp(Ws, bs, Id()))
 
+# size
+@test size(L) == (3, 3)
+
 # kernel size and number of filters
 @test kernel(L) == kernel(L2) == (2, 2, 1)
 @test n_filters(L) == 1 && n_filters(L2) == 2

test/Architecture/DenseLayerOp.jl

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ println(io, L)
 @test dim_out(L) == 3
 @test dim(L) == (2, 3)
 @test length(L) == 3
+@test size(L) == (1, 1)
 
 # test methods for all activations
 function test_layer(L::DenseLayerOp{Id})

test/Architecture/FeedforwardNetwork.jl

Lines changed: 13 additions & 0 deletions
@@ -48,3 +48,16 @@ println(io, N2)
 @test dim_in(N1) == 2 && dim_in(N2) == 2
 @test dim_out(N1) == 3 && dim_out(N2) == 2
 @test dim(N1) == (2, 3) && dim(N2) == (2, 2)
+
+# network with all layer types
+L1 = ConvolutionalLayerOp([reshape([1 0; -1 2], (2, 2, 1))], [1], ReLU())
+L2 = MaxPoolingLayerOp(1, 1)
+L3 = FlattenLayerOp()
+W = zeros(2, 9); W[1, 1] = W[2, 2] = 1
+L4 = DenseLayerOp(W, [1.0 0], ReLU())
+N3 = FeedforwardNetwork([L1, L2, L3, L4])
+T441 = reshape([0 4 2 1; -1 0 1 -2; 3 1 2 0; 0 1 4 1], (4, 4, 1))
+@test N3(T441) == [3.0 2; 8 7]
+
+# incompatible dimensions
+@test_throws ArgumentError FeedforwardNetwork([L1, L4])
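
The expected value `[3.0 2; 8 7]` can be reproduced by hand under two assumptions about the implementation, which the matrix shape of the result suggests: the convolution is a sliding dot product without kernel flipping, and the dense layer adds its bias with broadcasting (so the 2-vector output broadcast against the 1×2 bias yields a 2×2 matrix). A plain-Julia check of that trace, with no library code involved:

    # hand-computed convolution output (kernel [1 0; -1 2], bias 1, ReLU), assuming no kernel flip:
    O = [2.0 7 0; 0 4 0; 6 9 1]
    # 1x1 max pooling is the identity; row-major flattening of O:
    x = vec(permutedims(O, (2, 1)))          # [2, 7, 0, 0, 4, 0, 6, 9, 1]
    # dense layer: W picks the first two entries, the 1x2 bias broadcasts
    W = zeros(2, 9); W[1, 1] = W[2, 2] = 1
    max.(W * x .+ [1.0 0], 0) == [3.0 2; 8 7]  # true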

test/Architecture/FlattenLayerOp.jl

Lines changed: 3 additions & 0 deletions
@@ -6,3 +6,6 @@ L = FlattenLayerOp()
 
 # equality
 @test L == FlattenLayerOp()
+
+# size
+@test size(L) == (nothing, 1)

test/Architecture/PoolingLayerOp.jl

Lines changed: 3 additions & 0 deletions
@@ -39,3 +39,6 @@ for (L, LT) in zip(Ls, LTs)
     @test L != LT(2, 2) && L != LT(3, 3)
 end
 @test L1 != L2
+
+# size
+@test size(L1) == (3, 3)
