diff --git a/Project.toml b/Project.toml index daf2656..072b398 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ControllerFormats" uuid = "02ac4b2c-022a-44aa-84a5-ea45a5754bcc" -version = "0.2.1" +version = "0.2.2" [deps] ReachabilityBase = "379f33d0-9447-4353-bd03-d664070e549f" diff --git a/src/FileFormats/ONNX.jl b/src/FileFormats/ONNX.jl index bdcb47f..a1be0c7 100644 --- a/src/FileFormats/ONNX.jl +++ b/src/FileFormats/ONNX.jl @@ -73,7 +73,10 @@ function read_ONNX(filename::String; input_dimension=nothing) idx += 1 end n_layers = div(idx - 2, 2) - @assert length(ops) == 4 * n_layers + # 4 operations per layer +1 for the input operation + # (-1 potentially for implicit identity activation in the last layer) + @assert length(ops) == 4 * n_layers || length(ops) == 4 * n_layers + 1 "" * + "each layer should consist of 4 operations (except possibly the last one)" T = DenseLayerOp{<:ActivationFunction,Matrix{Float32},Vector{Float32}} layers = T[] layer = 1 @@ -98,9 +101,16 @@ function read_ONNX(filename::String; input_dimension=nothing) op = ops[idx] @assert op isa Umlaut.Call "expected an activation function" args = op.args - @assert length(args) == 2 - @assert args[2]._op.id == idx - 1 - a = available_activations[string(args[1])] + if length(args) == 1 + @assert args[1]._op.id == idx - 1 + act = op.fn + elseif length(args) == 2 + @assert args[2]._op.id == idx - 1 + act = args[1] + else + @assert false "cannot parse activation $op" + end + a = available_activations[string(act)] idx += 1 end diff --git a/test/FileFormats/ONNX.jl b/test/FileFormats/ONNX.jl index 4cac0c1..6a346fe 100644 --- a/test/FileFormats/ONNX.jl +++ b/test/FileFormats/ONNX.jl @@ -2,10 +2,22 @@ file = joinpath(@__DIR__, "sample_ONNX.onnx") # parse file -N = read_ONNX(file); +N = read_ONNX(file) # alternative parse with optional argument -N2 = read_ONNX(file; input_dimension=6); +N2 = read_ONNX(file; input_dimension=6) @test N == N2 @test length(N.layers) == 4 + +# alternative file with different activation encoding +file = joinpath(@__DIR__, "sample_ONNX2.onnx") + +# parse file +N = read_ONNX(file) + +# alternative parse with optional argument +N2 = read_ONNX(file; input_dimension=4) +@test N == N2 + +@test length(N.layers) == 3 diff --git a/test/FileFormats/sample_ONNX2.onnx b/test/FileFormats/sample_ONNX2.onnx new file mode 100644 index 0000000..c837ab8 Binary files /dev/null and b/test/FileFormats/sample_ONNX2.onnx differ