Skip to content

Commit c558c44

Browse files
committed
add FlattenLayerOp
1 parent 0289be9 commit c558c44

File tree

5 files changed

+58
-1
lines changed

5 files changed

+58
-1
lines changed

docs/src/lib/Architecture.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ AbstractPoolingLayerOp
6363

6464
```@docs
6565
DenseLayerOp
66+
FlattenLayerOp
6667
MaxPoolingLayerOp
6768
MeanPoolingLayerOp
6869
```

src/Architecture/Architecture.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ using Requires
99
using Statistics: mean
1010

1111
export AbstractNeuralNetwork, FeedforwardNetwork,
12-
AbstractLayerOp, DenseLayerOp,
12+
AbstractLayerOp, DenseLayerOp, FlattenLayerOp,
1313
AbstractPoolingLayerOp, MaxPoolingLayerOp, MeanPoolingLayerOp,
1414
layers, dim_in, dim_out,
1515
ActivationFunction, Id, ReLU, Sigmoid, Tanh, LeakyReLU
1616

1717
include("ActivationFunction.jl")
1818
include("LayerOps/AbstractLayerOp.jl")
1919
include("LayerOps/DenseLayerOp.jl")
20+
include("LayerOps/FlattenLayerOp.jl")
2021
include("LayerOps/PoolingLayerOp.jl")
2122
include("NeuralNetworks/AbstractNeuralNetwork.jl")
2223
include("NeuralNetworks/FeedforwardNetwork.jl")
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
"""
    FlattenLayerOp <: AbstractLayerOp

A flattening layer operation converts a multidimensional tensor into a vector.

### Notes

The implementation uses row-major ordering for convenience with the
machine-learning literature: the first two dimensions of the input are swapped
before flattening (column-major) to a vector.

```jldoctest
julia> T = reshape([1, 3, 2, 4, 5, 7, 6, 8], (2, 2, 2))
2×2×2 Array{Int64, 3}:
[:, :, 1] =
 1  2
 3  4

[:, :, 2] =
 5  6
 7  8

julia> FlattenLayerOp()(T)
8-element Vector{Int64}:
 1
 2
 3
 4
 5
 6
 7
 8
```
"""
struct FlattenLayerOp <: AbstractLayerOp end
36+
37+
# Application of a flattening layer to a tensor `T`.
#
# To follow the row-major convention, the first two dimensions are swapped via
# `permutedims` before the (column-major) flattening with `vec`; any trailing
# dimensions keep their order.  Vectors (and zero-dimensional arrays) need no
# permutation and are flattened directly.
function (L::FlattenLayerOp)(T)
    n = ndims(T)
    if n <= 1
        # already flat (or 0-dimensional); `permutedims` would be invalid here
        return vec(T)
    end
    return vec(permutedims(T, (2, 1, 3:n...)))
end

test/Architecture/FlattenLayerOp.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
L = FlattenLayerOp()

# output for matrix and vector inputs
@test L([1 2; 3 4]) == [1, 2, 3, 4]
@test L([1]) == [1]

# output for a three-dimensional tensor (exercises the `permutedims` branch
# with trailing dimensions beyond the first two)
T = reshape([1, 3, 2, 4, 5, 7, 6, 8], (2, 2, 2))
@test L(T) == [1, 2, 3, 4, 5, 6, 7, 8]

# equality (singleton struct: all instances are equal)
@test L == FlattenLayerOp()

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ struct TestActivation <: ActivationFunction end
1616
@testset "DenseLayerOp" begin
1717
include("Architecture/DenseLayerOp.jl")
1818
end
19+
@testset "FlattenLayerOp" begin
20+
include("Architecture/FlattenLayerOp.jl")
21+
end
1922
@testset "PoolingLayerOp" begin
2023
include("Architecture/PoolingLayerOp.jl")
2124
end

0 commit comments

Comments
 (0)