From 555b7b895709e57e4d50416c8bb03de9092d854e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 3 Nov 2021 00:40:59 +0100 Subject: [PATCH 01/13] Use scalar parameters with ParameterHandling --- Project.toml | 2 + src/KernelFunctions.jl | 4 + src/basekernels/constant.jl | 28 +++-- src/basekernels/cosine.jl | 2 + src/basekernels/exponential.jl | 24 +++-- src/basekernels/exponentiated.jl | 2 + src/basekernels/fbm.jl | 14 ++- src/basekernels/matern.jl | 20 +++- src/basekernels/nn.jl | 2 + src/basekernels/periodic.jl | 15 ++- src/basekernels/piecewisepolynomial.jl | 2 + src/basekernels/polynomial.jl | 43 +++++--- src/basekernels/rational.jl | 69 ++++++++---- src/basekernels/wiener.jl | 2 + src/kernels/gibbskernel.jl | 10 ++ src/kernels/kernelproduct.jl | 19 ++++ src/kernels/kernelsum.jl | 20 ++++ src/kernels/kerneltensorproduct.jl | 19 ++++ src/kernels/normalizedkernel.jl | 6 ++ src/kernels/parameterkernel.jl | 142 +++++++++++++++++++++++++ src/kernels/scaledkernel.jl | 33 ++++-- src/kernels/transformedkernel.jl | 15 +++ src/mokernels/independent.jl | 8 ++ src/mokernels/intrinsiccoregion.jl | 16 +++ src/mokernels/lmm.jl | 26 +++++ src/test_utils.jl | 4 + src/transform/ardtransform.jl | 6 +- src/transform/chaintransform.jl | 19 ++++ src/transform/lineartransform.jl | 6 +- src/transform/periodic_transform.jl | 22 ++-- src/transform/scaletransform.jl | 34 +++--- src/transform/transform.jl | 2 + src/utils.jl | 17 +++ 33 files changed, 560 insertions(+), 93 deletions(-) create mode 100644 src/kernels/parameterkernel.jl diff --git a/Project.toml b/Project.toml index d3a56f055..9f286dfb2 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -29,6 +30,7 @@ FillArrays = "0.10, 0.11, 0.12" Functors = "0.1, 0.2" IrrationalConstants = "0.1" LogExpFunctions = "0.2.1, 0.3" +ParameterHandling = "0.4" Requires = "1.0.1" SpecialFunctions = "0.8, 0.9, 0.10, 1" StatsBase = "0.32, 0.33" diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 37bde4a65..a9f9afa8e 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -41,6 +41,8 @@ export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_ou export IndependentMOKernel, LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel +export ParameterKernel + # Reexports export tensor, ⊗, compose @@ -51,6 +53,7 @@ using CompositionsBase using Distances using FillArrays using Functors +using ParameterHandling using LinearAlgebra using Requires using SpecialFunctions: loggamma, besselk, polygamma @@ -107,6 +110,7 @@ include("kernels/kernelproduct.jl") include("kernels/kerneltensorproduct.jl") include("kernels/overloads.jl") include("kernels/neuralkernelnetwork.jl") +include("kernels/parameterkernel.jl") include("approximations/nystrom.jl") include("generic.jl") diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl index 5996546f1..21879a9be 100644 --- a/src/basekernels/constant.jl +++ b/src/basekernels/constant.jl @@ -15,7 +15,9 @@ See also: [`ConstantKernel`](@ref) """ struct ZeroKernel <: SimpleKernel end -kappa(κ::ZeroKernel, d::T) where {T<:Real} = 
zero(T) +@noparams ZeroKernel + +kappa(::ZeroKernel, d::Real) = zero(d) metric(::ZeroKernel) = Delta() @@ -35,6 +37,8 @@ k(x, x') = \\delta(x, x'). """ struct WhiteKernel <: SimpleKernel end +@noparams WhiteKernel + """ EyeKernel() @@ -62,19 +66,27 @@ k(x, x') = c. See also: [`ZeroKernel`](@ref) """ -struct ConstantKernel{Tc<:Real} <: SimpleKernel - c::Vector{Tc} +struct ConstantKernel{T<:Real} <: SimpleKernel + c::T - function ConstantKernel(; c::Real=1.0) + function ConstantKernel(c::Real) @check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0") - return new{typeof(c)}([c]) + return new{typeof(c)}(c) end end -@functor ConstantKernel +ConstantKernel(; c::Real=1.0) = ConstantKernel(c) + +function ParameterHandling.flatten(::Type{T}, k::ConstantKernel{S}) where {T<:Real,S} + function unflatten_to_constantkernel(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + ConstantKernel(; c=S(exp(first(v)))) + end + return T[log(k.c)], unflatten_to_constantkernel +end -kappa(κ::ConstantKernel, x::Real) = first(κ.c) * one(x) +kappa(κ::ConstantKernel, x::Real) = κ.c * one(x) metric(::ConstantKernel) = Delta() -Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", first(κ.c), ")") +Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", κ.c, ")") diff --git a/src/basekernels/cosine.jl b/src/basekernels/cosine.jl index 50cc6fdf3..4c1040822 100644 --- a/src/basekernels/cosine.jl +++ b/src/basekernels/cosine.jl @@ -17,6 +17,8 @@ end CosineKernel(; metric=Euclidean()) = CosineKernel(metric) +@noparams CosineKernel + kappa(::CosineKernel, d::Real) = cospi(d) metric(k::CosineKernel) = k.metric diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl index c7a788b8a..f535df259 100644 --- a/src/basekernels/exponential.jl +++ b/src/basekernels/exponential.jl @@ -20,6 +20,8 @@ end SqExponentialKernel(; metric=Euclidean()) = SqExponentialKernel(metric) +@noparams SqExponentialKernel + kappa(::SqExponentialKernel, d::Real) = exp(-d^2 / 2) kappa(::SqExponentialKernel{<:Euclidean}, d²::Real) = exp(-d² / 2) @@ -76,6 +78,8 @@ end ExponentialKernel(; metric=Euclidean()) = ExponentialKernel(metric) +@noparams ExponentialKernel + kappa(::ExponentialKernel, d::Real) = exp(-d) metric(k::ExponentialKernel) = k.metric @@ -121,13 +125,13 @@ See also: [`ExponentialKernel`](@ref), [`SqExponentialKernel`](@ref) [^RW]: C. E. Rasmussen & C. K. I. Williams (2006). Gaussian Processes for Machine Learning. 
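+
+A sketch of the `ParameterHandling.flatten` round trip (the unconstrained
+parametrization of `γ` used internally is an implementation detail):
+
+```julia
+julia> k = GammaExponentialKernel(; γ=1.5);
+
+julia> v, unflatten = ParameterHandling.flatten(k);
+
+julia> unflatten(v) isa GammaExponentialKernel
+true
+```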
""" -struct GammaExponentialKernel{Tγ<:Real,M} <: SimpleKernel - γ::Vector{Tγ} +struct GammaExponentialKernel{T<:Real,M} <: SimpleKernel + γ::T metric::M function GammaExponentialKernel(γ::Real, metric) @check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") - return new{typeof(γ),typeof(metric)}([γ], metric) + return new{typeof(γ),typeof(metric)}(γ, metric) end end @@ -135,9 +139,17 @@ function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclid return GammaExponentialKernel(γ, metric) end -@functor GammaExponentialKernel +function ParameterHandling.flatten(::Type{T}, k::GammaExponentialKernel{S}) where {T<:Real,S<:Real} + metric = k.metric + function unflatten_to_gammaexponentialkernel(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + γ = S(1 + logistic(first(v))) + return GammaExponentialKernel(; γ=γ, metric=metric) + end + return T[logit(k.γ - 1)], unflatten_to_gammaexponentialkernel +end -kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^first(κ.γ)) +kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^κ.γ) metric(k::GammaExponentialKernel) = k.metric @@ -145,6 +157,6 @@ iskroncompatible(::GammaExponentialKernel) = true function Base.show(io::IO, κ::GammaExponentialKernel) return print( - io, "Gamma Exponential Kernel (γ = ", first(κ.γ), ", metric = ", κ.metric, ")" + io, "Gamma Exponential Kernel (γ = ", κ.γ, ", metric = ", κ.metric, ")" ) end diff --git a/src/basekernels/exponentiated.jl b/src/basekernels/exponentiated.jl index 0b360ceb6..66888f750 100644 --- a/src/basekernels/exponentiated.jl +++ b/src/basekernels/exponentiated.jl @@ -12,6 +12,8 @@ k(x, x') = \\exp(x^\\top x'). """ struct ExponentiatedKernel <: SimpleKernel end +@noparams ExponentiatedKernel + kappa(::ExponentiatedKernel, xᵀy::Real) = exp(xᵀy) metric(::ExponentiatedKernel) = DotProduct() diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index 213cb3c36..980747ff7 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -13,16 +13,24 @@ k(x, x'; h) = \\frac{\\|x\\|_2^{2h} + \\|x'\\|_2^{2h} - \\|x - x'\\|^{2h}}{2}. ``` """ struct FBMKernel{T<:Real} <: Kernel - h::Vector{T} + h::T + function FBMKernel(h::Real) @check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]") - return new{typeof(h)}([h]) + return new{typeof(h)}(h) end end FBMKernel(; h::Real=0.5) = FBMKernel(h) -@functor FBMKernel +function ParameterHandling.flatten(::Type{T}, k::FBMKernel{S}) where {T<:Real,S<:Real} + function unflatten_to_fbmkernel(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + h = S((1 + logistic(first(v))) / 2) + return FBMKernel(h) + end + return T[logit(2 * k.h - 1)], unflatten_to_fbmkernel +end function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) modX = sum(abs2, x) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index a3c20efd7..ffb76ffdd 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -19,19 +19,27 @@ differentiable in the mean-square sense. 
See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref) """ -struct MaternKernel{Tν<:Real,M} <: SimpleKernel - ν::Vector{Tν} +struct MaternKernel{T<:Real,M} <: SimpleKernel + ν::T metric::M function MaternKernel(ν::Real, metric) @check_args(MaternKernel, ν, ν > zero(ν), "ν > 0") - return new{typeof(ν),typeof(metric)}([ν], metric) + return new{typeof(ν),typeof(metric)}(ν, metric) end end MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, metric) -@functor MaternKernel +function ParameterHandling.flatten(::Type{T}, k::MaternKernel{S}) where {T<:Real,S<:Real} + metric = k.metric + function unflatten_to_maternkernel(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + v = S(exp(first(v))) + return MaternKernel(ν, metric) + end + return T[log(k.ν)], unflatten_to_maternkernel +end @inline function kappa(κ::MaternKernel, d::Real) result = _matern(first(κ.ν), d) @@ -73,6 +81,8 @@ end Matern32Kernel(; metric=Euclidean()) = Matern32Kernel(metric) +@noparams Matern32Kernel + kappa(::Matern32Kernel, d::Real) = (1 + sqrt(3) * d) * exp(-sqrt(3) * d) metric(k::Matern32Kernel) = k.metric @@ -104,6 +114,8 @@ end Matern52Kernel(; metric=Euclidean()) = Matern52Kernel(metric) +@noparams Matern52Kernel + kappa(::Matern52Kernel, d::Real) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d) metric(k::Matern52Kernel) = k.metric diff --git a/src/basekernels/nn.jl b/src/basekernels/nn.jl index 40070075d..1e45c5a32 100644 --- a/src/basekernels/nn.jl +++ b/src/basekernels/nn.jl @@ -33,6 +33,8 @@ for inputs ``x, x' \\in \\mathbb{R}^d``.[^CW] """ struct NeuralNetworkKernel <: Kernel end +@noparams NeuralNetworkKernel + function (κ::NeuralNetworkKernel)(x, y) return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y)))) end diff --git a/src/basekernels/periodic.jl b/src/basekernels/periodic.jl index 2758d7f94..a0464f530 100644 --- a/src/basekernels/periodic.jl +++ b/src/basekernels/periodic.jl @@ -21,16 +21,25 @@ struct PeriodicKernel{T} <: SimpleKernel end end +""" + PeriodicKernel(dims::Int) + +Create a [`PeriodicKernel`](@ref) with parameter `r=ones(Float64, dims)`. +""" PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims) """ - PeriodicKernel([T=Float64, dims::Int=1]) + PeriodicKernel(T, dims::Int=1) Create a [`PeriodicKernel`](@ref) with parameter `r=ones(T, dims)`. """ -PeriodicKernel(T::DataType, dims::Int=1) = PeriodicKernel(; r=ones(T, dims)) +PeriodicKernel(::Type{T}, dims::Int=1) where {T} = PeriodicKernel(; r=ones(T, dims)) -@functor PeriodicKernel +function ParameterHandling.flatten(::Type{T}, k::PeriodicKernel{S}) where {T<:Real,S} + vec = T[log(ri) for ri in k.r] + unflatten_to_periodickernel(v::Vector{T}) = PeriodicKernel(; r=S[exp(vi) for vi in v]) + return vec, unflatten_to_periodickernel +end metric(κ::PeriodicKernel) = Sinus(κ.r) diff --git a/src/basekernels/piecewisepolynomial.jl b/src/basekernels/piecewisepolynomial.jl index 07b3638dd..39d8f7cf3 100644 --- a/src/basekernels/piecewisepolynomial.jl +++ b/src/basekernels/piecewisepolynomial.jl @@ -46,6 +46,8 @@ function PiecewisePolynomialKernel(; degree::Int=0, kwargs...) return PiecewisePolynomialKernel{degree}(; kwargs...) 
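+    # `degree` is a fixed discrete choice that is encoded in the kernel's type,
+    # so there are no trainable parameters to flatten (hence `@noparams` below).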
end +@noparams PiecewisePolynomialKernel + piecewise_polynomial_coefficients(::Val{0}, ::Int) = (1,) piecewise_polynomial_coefficients(::Val{1}, j::Int) = (1, j + 1) piecewise_polynomial_coefficients(::Val{2}, j::Int) = (1, j + 2, (j^2 + 4 * j)//3 + 1) diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index da686e2c9..8a42bb962 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -13,24 +13,30 @@ k(x, x'; c) = x^\\top x' + c. See also: [`PolynomialKernel`](@ref) """ -struct LinearKernel{Tc<:Real} <: SimpleKernel - c::Vector{Tc} +struct LinearKernel{T<:Real} <: SimpleKernel + c::T function LinearKernel(c::Real) @check_args(LinearKernel, c, c >= zero(c), "c ≥ 0") - return new{typeof(c)}([c]) + return new{typeof(c)}(c) end end LinearKernel(; c::Real=0.0) = LinearKernel(c) -@functor LinearKernel +function ParameterHandling.flatten(::Type{T}, k::LinearKernel{S}) where {T<:Real,S<:Real} + function unflatten_to_linearkernel(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + return LinearKernel(S(first(v))) + end + return T[k.c], unflatten_to_linearkernel +end -kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c) +kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + κ.c metric(::LinearKernel) = DotProduct() -Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first(κ.c), ")") +Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", κ.c, ")") """ PolynomialKernel(; degree::Int=2, c::Real=0.0) @@ -47,14 +53,14 @@ k(x, x'; c, \\nu) = (x^\\top x' + c)^\\nu. See also: [`LinearKernel`](@ref) """ -struct PolynomialKernel{Tc<:Real} <: SimpleKernel +struct PolynomialKernel{T<:Real} <: SimpleKernel degree::Int - c::Vector{Tc} + c::T - function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc} + function PolynomialKernel(degree::Int, c::Real) @check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1") - @check_args(PolynomialKernel, c, first(c) >= zero(Tc), "c ≥ 0") - return new{Tc}(degree, c) + @check_args(PolynomialKernel, c, c >= zero(c), "c ≥ 0") + return new{typeof(c)}(degree, c) end end @@ -62,16 +68,19 @@ function PolynomialKernel(; degree::Int=2, c::Real=0.0) return PolynomialKernel{typeof(c)}(degree, [c]) end -# The degree of the polynomial kernel is a fixed discrete parameter -function Functors.functor(::Type{<:PolynomialKernel}, x) - reconstruct_polynomialkernel(xs) = PolynomialKernel{typeof(xs.c)}(x.degree, xs.c) - return (c=x.c,), reconstruct_polynomialkernel +function ParameterHandling.flatten(::Type{T}, k::PolynomialKernel{S}) where {T<:Real,S<:Real} + degree = k.degree + function unflatten_to_polynomialkernel(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + return PolynomialKernel(degree, S(first(v))) + end + return T[k.c], unflatten_to_polynomialkernel end -kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + first(κ.c))^κ.degree +kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + κ.c)^κ.degree metric(::PolynomialKernel) = DotProduct() function Base.show(io::IO, κ::PolynomialKernel) - return print(io, "Polynomial Kernel (c = ", first(κ.c), ", degree = ", κ.degree, ")") + return print(io, "Polynomial Kernel (c = ", κ.c, ", degree = ", κ.degree, ")") end diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index 8ed396b51..69fd32fdd 100644 --- a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -15,13 +15,13 @@ The [`ExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\in See also: 
[`GammaRationalKernel`](@ref) """ -struct RationalKernel{Tα<:Real,M} <: SimpleKernel - α::Vector{Tα} +struct RationalKernel{T<:Real,M} <: SimpleKernel + α::T metric::M function RationalKernel(α::Real, metric) @check_args(RationalKernel, α, α > zero(α), "α > 0") - return new{typeof(α),typeof(metric)}([α], metric) + return new{typeof(α),typeof(metric)}(α, metric) end end @@ -29,16 +29,24 @@ function RationalKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean()) return RationalKernel(α, metric) end -@functor RationalKernel +function ParameterHandling.flatten(::Type{T}, k::RationalKernel{S}) where {T<:Real,S} + metric = k.metric + function unflatten_to_rationalkernel(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + return ConstantKernel(S(exp(first(v))), metric) + end + return T[log(k.α)], unflatten_to_rationalkernel +end function kappa(κ::RationalKernel, d::Real) - return (one(d) + d / first(κ.α))^(-first(κ.α)) + α = κ.α + return (one(d) + d / α)^(-α) end metric(k::RationalKernel) = k.metric function Base.show(io::IO, κ::RationalKernel) - return print(io, "Rational Kernel (α = ", first(κ.α), ", metric = ", κ.metric, ")") + return print(io, "Rational Kernel (α = ", κ.α, ", metric = ", κ.metric, ")") end """ @@ -59,23 +67,32 @@ The [`SqExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\ See also: [`GammaRationalKernel`](@ref) """ -struct RationalQuadraticKernel{Tα<:Real,M} <: SimpleKernel - α::Vector{Tα} +struct RationalQuadraticKernel{T<:Real,M} <: SimpleKernel + α::T metric::M function RationalQuadraticKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean()) @check_args(RationalQuadraticKernel, α, α > zero(α), "α > 0") - return new{typeof(α),typeof(metric)}([α], metric) + return new{typeof(α),typeof(metric)}(α, metric) end end -@functor RationalQuadraticKernel +function ParameterHandling.flatten(::Type{T}, k::RationalQuadraticKernel{S}) where {T<:Real,S} + metric = k.metric + function unflatten_to_rationalquadratickernel(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + return RationalQuadraticKernel(; α=S(exp(first(v))), metric=metric) + end + return T[log(k.α)], unflatten_to_rationalquadratickernel +end function kappa(κ::RationalQuadraticKernel, d::Real) - return (one(d) + d^2 / (2 * first(κ.α)))^(-first(κ.α)) + α = κ.α + return (one(d) + d^2 / (2 * α))^(-α) end function kappa(κ::RationalQuadraticKernel{<:Real,<:Euclidean}, d²::Real) - return (one(d²) + d² / (2 * first(κ.α)))^(-first(κ.α)) + α = κ.α + return (one(d²) + d² / (2 * α))^(-α) end metric(k::RationalQuadraticKernel) = k.metric @@ -83,7 +100,7 @@ metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean() function Base.show(io::IO, κ::RationalQuadraticKernel) return print( - io, "Rational Quadratic Kernel (α = ", first(κ.α), ", metric = ", κ.metric, ")" + io, "Rational Quadratic Kernel (α = ", κ.α, ", metric = ", κ.metric, ")" ) end @@ -106,8 +123,8 @@ The [`GammaExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to See also: [`RationalKernel`](@ref), [`RationalQuadraticKernel`](@ref) """ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel - α::Vector{Tα} - γ::Vector{Tγ} + α::Tα + γ::Tγ metric::M function GammaRationalKernel(; @@ -115,14 +132,28 @@ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel ) @check_args(GammaRationalKernel, α, α > zero(α), "α > 0") @check_args(GammaRationalKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") - return new{typeof(α),typeof(γ),typeof(metric)}([α], [γ], metric) + return 
new{typeof(α),typeof(γ),typeof(metric)}(α, γ, metric) end end @functor GammaRationalKernel +function ParameterHandling.flatten(::Type{T}, k::GammaRationalKernel{Tα,Tγ}) where {T<:Real,Tα,Tγ} + vec = T[log(k.α), logit(k.γ - 1)] + metric = k.metric + function unflatten_to_gammarationalkernel(v::Vector{T}) + length(v) == 2 || error("incorrect number of parameters") + logα, logitγ = v + α = Tα(exp(logα)) + γ = Tγ(1 + logistic(logitγ)) + return GammaRationalKernel(; α=α, γ=γ, metric=metric) + end + return vec, unflatten_to_gammarationalkernel +end + function kappa(κ::GammaRationalKernel, d::Real) - return (one(d) + d^first(κ.γ) / first(κ.α))^(-first(κ.α)) + α = κ.α + return (one(d) + d^κ.γ / α)^(-α) end metric(k::GammaRationalKernel) = k.metric @@ -131,9 +162,9 @@ function Base.show(io::IO, κ::GammaRationalKernel) return print( io, "Gamma Rational Kernel (α = ", - first(κ.α), + κ.α, ", γ = ", - first(κ.γ), + κ.γ, ", metric = ", κ.metric, ")", diff --git a/src/basekernels/wiener.jl b/src/basekernels/wiener.jl index 14d330850..741a5960f 100644 --- a/src/basekernels/wiener.jl +++ b/src/basekernels/wiener.jl @@ -52,6 +52,8 @@ function WienerKernel(; i::Integer=0) return WienerKernel{i}() end +@noparams WienerKernel + function (::WienerKernel{0})(x, y) X = sqrt(sum(abs2, x)) Y = sqrt(sum(abs2, y)) diff --git a/src/kernels/gibbskernel.jl b/src/kernels/gibbskernel.jl index 46e14995d..d1bd9f4ec 100644 --- a/src/kernels/gibbskernel.jl +++ b/src/kernels/gibbskernel.jl @@ -36,6 +36,16 @@ end GibbsKernel(; lengthscale) = GibbsKernel(lengthscale) +@functor GibbsKernel + +# or just `@noparams GibbsKernel` - it would be safer since there is no +# default fallback for `flatten` +function ParameterHandling.flatten(::Type{T}, k::GibbsKernel) where {T<:Real} + vec, unflatten_to_lengthscale = flatten(T, k.lengthscale) + unflatten_to_gibbskernel(v::Vector{T}) = GibbsKernel(unflatten_to_lengthscale(v)) + return vec, unflatten_to_gibbskernel +end + function (k::GibbsKernel)(x, y) lengthscale = k.lengthscale lx = lengthscale(x) diff --git a/src/kernels/kernelproduct.jl b/src/kernels/kernelproduct.jl index 990b4a1bb..476843b13 100644 --- a/src/kernels/kernelproduct.jl +++ b/src/kernels/kernelproduct.jl @@ -41,6 +41,25 @@ end @functor KernelProduct +function ParameterHandling.flatten(::Type{T}, k::KernelProduct) where {T<:Real} + vecs_and_backs = map(Base.Fix1(flatten, T), k.kernels) + vecs = map(first, vecs_and_backs) + length_vecs = map(length, vecs) + backs = map(last, vecs_and_backs) + flat_vecs = reduce(vcat, vecs) + function unflatten_to_kernelproduct(v::Vector{T}) + length(v) == length(flat_vecs) || error("incorrect number of parameters") + offset = Ref(0) + kernels = map(backs, length_vecs) do back, length_vec + oldoffset = offset[] + newoffset = offset[] = oldoffset + length_vec + return back(v[(oldoffset + 1):newoffset]) + end + return KernelProduct(kernels) + end + return flat_vecs, unflatten_to_kernelproduct +end + Base.length(k::KernelProduct) = length(k.kernels) (κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels) diff --git a/src/kernels/kernelsum.jl b/src/kernels/kernelsum.jl index 6c4c8d499..cc9f51b73 100644 --- a/src/kernels/kernelsum.jl +++ b/src/kernels/kernelsum.jl @@ -41,6 +41,26 @@ end @functor KernelSum +function ParameterHandling.flatten(::Type{T}, k::KernelSum) where {T<:Real} + vecs_and_backs = map(Base.Fix1(flatten, T), k.kernels) + vecs = map(first, vecs_and_backs) + length_vecs = map(length, vecs) + backs = map(last, vecs_and_backs) + flat_vecs = reduce(vcat, vecs) + n = 
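+    # Cache the total parameter count: `unflatten_to_kernelsum` below validates
+    # its input against it before splitting `v` back into per-kernel segments.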
length(flat_vecs) + function unflatten_to_kernelsum(v::Vector{T}) + length(v) == n || error("incorrect number of parameters") + offset = Ref(0) + kernels = map(backs, length_vecs) do back, length_vec + oldoffset = offset[] + newoffset = offset[] = oldoffset + length_vec + return back(v[(oldoffset + 1):newoffset]) + end + return KernelSum(kernels) + end + return flat_vecs, unflatten_to_kernelsum +end + Base.length(k::KernelSum) = length(k.kernels) (κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels) diff --git a/src/kernels/kerneltensorproduct.jl b/src/kernels/kerneltensorproduct.jl index ea0044409..ab3f3486a 100644 --- a/src/kernels/kerneltensorproduct.jl +++ b/src/kernels/kerneltensorproduct.jl @@ -47,6 +47,25 @@ end @functor KernelTensorProduct +function ParameterHandling.flatten(::Type{T}, k::KernelTensorProduct) where {T<:Real} + vecs_and_backs = map(Base.Fix1(flatten, T), k.kernels) + vecs = map(first, vecs_and_backs) + length_vecs = map(length, vecs) + backs = map(last, vecs_and_backs) + flat_vecs = reduce(vcat, vecs) + function unflatten_to_kerneltensorproduct(v::Vector{T}) + length(v) == length(flat_vecs) || error("incorrect number of parameters") + offset = Ref(0) + kernels = map(backs, length_vecs) do back, length_vec + oldoffset = offset[] + newoffset = offset[] = oldoffset + length_vec + return back(v[(oldoffset + 1):newoffset]) + end + return KernelTensorProduct(kernels) + end + return flat_vecs, unflatten_to_kerneltensorproduct +end + Base.length(kernel::KernelTensorProduct) = length(kernel.kernels) function (kernel::KernelTensorProduct)(x, y) diff --git a/src/kernels/normalizedkernel.jl b/src/kernels/normalizedkernel.jl index 5644f11a0..17591bc24 100644 --- a/src/kernels/normalizedkernel.jl +++ b/src/kernels/normalizedkernel.jl @@ -17,6 +17,12 @@ end @functor NormalizedKernel +function ParameterHandling.flatten(::Type{T}, k::NormalizedKernel) where {T<:Real} + vec, back = flatten(T, k.kernel) + unflatten_to_normalizedkernel(v::Vector{T}) = NormalizedKernel(back(v)) + return vec, unflatten_to_normalizedkernel +end + (κ::NormalizedKernel)(x, y) = κ.kernel(x, y) / sqrt(κ.kernel(x, x) * κ.kernel(y, y)) function kernelmatrix(κ::NormalizedKernel, x::AbstractVector, y::AbstractVector) diff --git a/src/kernels/parameterkernel.jl b/src/kernels/parameterkernel.jl new file mode 100644 index 000000000..0f925abe9 --- /dev/null +++ b/src/kernels/parameterkernel.jl @@ -0,0 +1,142 @@ +""" + ParameterKernel(params, kernel) + +Kernel with parameters `params` that can be instantiated by calling `kernel(params)`. + +This kernel is particularly useful if you want to optimize a vector of, +usually unconstrained, kernel parameters `params` with e.g. +[Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl) or +[Flux.jl](https://github.com/FluxML/Flux.jl). + +# Examples + +There are two different approaches for obtaining the parameters `params` and the function +`kernel` from which a `ParameterKernel` can be constructed. + +## Extracting parameters from an existing kernel + +You can extract the parameters `params` and the function `kernel` from an existing kernel +`k` with `ParameterHandling.flatten`: +```jldoctest parameterkernel1 +julia> k = 2.0 * (RationalQuadraticKernel(; α=1.0) + ConstantKernel(; c=2.5)); + +julia> params, kernel = ParameterHandling.flatten(k); +``` + +Here `params` is a vector of the three parameters of kernel `k`. In this example, all these +parameters must be positive (otherwise `k` would not be a positive-definite kernel). To +simplify unconstrained optimization with e.g. 
Optim.jl or Flux.jl, +`ParameterHandling.flatten` automatically transforms the parameters to unconstrained values: +```jldoctest parameterkernel1 +julia> params ≈ map(log, [1.0, 2.5, 2.0]) +true +``` + +Kernel `k` can be reconstructed with the `kernel` function: +```jldoctest parameterkernel1 +julia> kernel(params) +Sum of 2 kernels: + Rational Quadratic Kernel (α = 1.0, metric = Distances.Euclidean(0.0)) + Constant Kernel (c = 2.5) + - σ² = 2.0 +``` + +As expected, different parameter values yield a kernel of the same structure with different +parameters: +```jldoctest parameterkernel1 +julia> kernel([log(0.25), log(0.5), log(2.0)]) +Sum of 2 kernels: + Rational Quadratic Kernel (α = 0.25, metric = Distances.Euclidean(0.0)) + Constant Kernel (c = 0.5) + - σ² = 2.0 +``` + +## Defining a function that constructs the kernel + +Instead of extracting parameters and a reconstruction function from an existing kernel you +can explicitly define a function that constructs the kernel of interest and a set of +parameters. + +```jldoctest parameterkernel2 +julia> using LogExpFunctions + +julia> function kernel(params) + length(params) == 1 || throw(ArgumentError("incorrect number of parameters")) + p = first(params) + return 2 * (RationalQuadraticKernel(; α=log1pexp(p)) + ConstantKernel(; c=exp(p))) + end; +``` + +With the function `kernel` kernels of the same structure as in the example above can be +constructed: +```jldoctest parameterkernel2 +julia> kernel([log(0.5)]) +Sum of 2 kernels: + Rational Quadratic Kernel (α = 0.4054651081081644, metric = Distances.Euclidean(0.0)) + Constant Kernel (c = 0.5) + - σ² = 2 +``` + +This example shows that defining `kernel` manually has some advantages over using +`ParameterHandling.flatten`: +- Kernel parameters can be fixed (scale parameter is always set to `2` in this example) +- Kernel parameters can be transformed from unconstrained to constrained space with + non-default mappings (shape parameter `α` is transformed with `log1pexp`) +- Kernel parameters can be linked (parameters `α` and `c` are computed from a single + parameter `p`) + +See also: [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl) +""" +struct ParameterKernel{P,K} <: Kernel + params::P + kernel::K +end + +Functors.@functor ParameterKernel (params,) + +function ParameterHandling.flatten(::Type{T}, kernel::ParameterKernel) where {T<:Real} + params_vec, unflatten_to_params = flatten(T, kernel.params) + k = kernel.kernel + function unflatten_to_parameterkernel(v::Vector{T}) + return ParameterKernel(unflatten_to_params(v), k) + end + return params_vec, unflatten_to_parameterkernel +end + +(k::ParameterKernel)(x, y) = k.kernel(k.params)(x, y) + +function kernelmatrix(k::ParameterKernel, x::AbstractVector) + return kernelmatrix(k.kernel(k.params), x) +end + +function kernelmatrix(k::ParameterKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix(k.kernel(k.params), x, y) +end + +function kernelmatrix!(K::AbstractMatrix, k::ParameterKernel, x::AbstractVector) + return kernelmatrix!(K, k.kernel(k.params), x) +end + +function kernelmatrix!( + K::AbstractMatrix, k::ParameterKernel, x::AbstractVector, y::AbstractVector +) + return kernelmatrix!(K, k.kernel(k.params), x, y) +end + +function kernelmatrix_diag(k::ParameterKernel, x::AbstractVector) + return kernelmatrix_diag(k.kernel(k.params), x) +end + +function kernelmatrix_diag(k::ParameterKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix_diag(k.kernel(k.params), x, y) +end + +function 
kernelmatrix_diag!(K::AbstractVector, k::ParameterKernel, x::AbstractVector) + return kernelmatrix_diag!(K, k.kernel(k.params), x) +end + +function kernelmatrix_diag!( + K::AbstractVector, k::ParameterKernel, x::AbstractVector, y::AbstractVector +) + return kernelmatrix_diag!(K, k.kernel(k.params), x, y) +end diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 897bdda1a..4361064f7 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -13,17 +13,30 @@ multiplication with variance ``\\sigma^2 > 0`` is defined as """ struct ScaledKernel{Tk<:Kernel,Tσ²<:Real} <: Kernel kernel::Tk - σ²::Vector{Tσ²} -end + σ²::Tσ² -function ScaledKernel(kernel::Tk, σ²::Tσ²=1.0) where {Tk<:Kernel,Tσ²<:Real} - @check_args(ScaledKernel, σ², σ² > zero(Tσ²), "σ² > 0") - return ScaledKernel{Tk,Tσ²}(kernel, [σ²]) + function ScaledKernel(kernel::Kernel, σ²::Real) + @check_args(ScaledKernel, σ², σ² > zero(σ²), "σ² > 0") + return new{typeof(kernel),typeof(σ²)}(kernel, σ²) + end end -@functor ScaledKernel +ScaledKernel(kernel::Kernel) = ScaledKernel(kernel, 1.0) + +# σ² is a positive parameter (and a scalar!) but Functors does not handle +# parameter constraints +@functor ScaledKernel (kernel,) + +function ParameterHandling.flatten(::Type{T}, k::ScaledKernel{<:Kernel,S}) where {T<:Real,S<:Real} + kernel_vec, kernel_back = flatten(T, k.kernel) + function unflatten_to_scaledkernel(v::Vector{T}) + kernel = kernel_back(v[1:end-1]) + return ScaledKernel(kernel, S(exp(last(v)))) + end + return vcat(kernel_vec, T(log(k.σ²))), unflatten_to_scaledkernel +end -(k::ScaledKernel)(x, y) = first(k.σ²) * k.kernel(x, y) +(k::ScaledKernel)(x, y) = k.σ² * k.kernel(x, y) function kernelmatrix(κ::ScaledKernel, x::AbstractVector, y::AbstractVector) return κ.σ² .* kernelmatrix(κ.kernel, x, y) @@ -75,5 +88,9 @@ Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0) function printshifted(io::IO, κ::ScaledKernel, shift::Int) printshifted(io, κ.kernel, shift) - return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(first(κ.σ²))") + print(io, "\n") + for _ in 1:(shift + 1) + print(io, "\t") + end + print(io, "- σ² = ", κ.σ²) end diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 94ae5c147..57c2a1d0c 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -16,6 +16,21 @@ end @functor TransformedKernel +function ParameterHandling.flatten(::Type{T}, k::TransformedKernel) where {T<:Real} + kernel_vec, kernel_back = flatten(T, k.kernel) + transform_vec, transform_back = flatten(T, k.transform) + v = vcat(kernel_vec, transform_vec) + n = length(v) + nkernel = length(kernel_vec) + function unflatten_to_transformedkernel(v::Vector{T}) + length(v) == n || error("incorrect number of parameters") + kernel = kernel_back(v[1:nkernel]) + transform = transform_back(v[(nkernel + 1):end]) + return TransformedKernel(kernel, transform) + end + return v, unflatten_to_transformedkernel +end + (k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y)) # Optimizations for scale transforms of simple kernels to save allocations: diff --git a/src/mokernels/independent.jl b/src/mokernels/independent.jl index 1f7811b14..d722146c0 100644 --- a/src/mokernels/independent.jl +++ b/src/mokernels/independent.jl @@ -23,6 +23,14 @@ struct IndependentMOKernel{Tkernel<:Kernel} <: MOKernel kernel::Tkernel end +@functor IndependentMOKernel + +function ParameterHandling.flatten(::Type{T}, k::IndependentMOKernel) where {T<:Real} + vec, unflatten_to_kernel = 
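+    # The multi-output wrapper adds no parameters of its own, so flattening
+    # defers entirely to the wrapped single-output kernel.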
flatten(T, k.kernel)
+    unflatten_to_independentmokernel(v::Vector{T}) = IndependentMOKernel(unflatten_to_kernel(v))
+    return vec, unflatten_to_independentmokernel
+end
+
 function (κ::IndependentMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
     return κ.kernel(x, y) * (px == py)
 end
diff --git a/src/mokernels/intrinsiccoregion.jl b/src/mokernels/intrinsiccoregion.jl
index 0a940796b..223485422 100644
--- a/src/mokernels/intrinsiccoregion.jl
+++ b/src/mokernels/intrinsiccoregion.jl
@@ -38,6 +38,22 @@ function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix)
     return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B)
 end
 
+@functor IntrinsicCoregionMOKernel (kernel,)
+
+function ParameterHandling.flatten(::Type{T}, k::IntrinsicCoregionMOKernel) where {T<:Real}
+    kernel_vec, unflatten_to_kernel = flatten(T, k.kernel)
+    B_vec, unflatten_to_B = value_flatten(T, positive_definite(k.B))
+    nkernel = length(kernel_vec)
+    ntotal = nkernel + length(B_vec)
+    function unflatten_to_intrinsiccoregionkernel(v::Vector{T})
+        length(v) == ntotal || error("incorrect number of parameters")
+        kernel = unflatten_to_kernel(v[1:nkernel])
+        B = unflatten_to_B(v[(nkernel + 1):end])
+        return IntrinsicCoregionMOKernel(kernel, B)
+    end
+    return vcat(kernel_vec, B_vec), unflatten_to_intrinsiccoregionkernel
+end
+
 function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
     return k.B[px, py] * k.kernel(x, y)
 end
diff --git a/src/mokernels/lmm.jl b/src/mokernels/lmm.jl
index 0045d520e..8e37114e7 100644
--- a/src/mokernels/lmm.jl
+++ b/src/mokernels/lmm.jl
@@ -31,6 +31,32 @@ function LinearMixingModelKernel(k::Kernel, H::AbstractMatrix)
     return LinearMixingModelKernel(Fill(k, size(H, 1)), H)
 end
 
+@functor LinearMixingModelKernel
+
+function ParameterHandling.flatten(::Type{T}, k::LinearMixingModelKernel) where {T<:Real}
+    kernel_vecs_and_backs = map(Base.Fix1(flatten, T), k.kernels)
+    kernel_vecs = map(first, kernel_vecs_and_backs)
+    length_kernel_vecs = map(length, kernel_vecs)
+    kernel_backs = map(last, kernel_vecs_and_backs)
+    H_vec, H_back = flatten(T, k.H)
+    flat_kernel_vecs = reduce(vcat, kernel_vecs)
+    nkernel = length(flat_kernel_vecs)
+    flat_vecs = vcat(flat_kernel_vecs, H_vec)
+    n = length(flat_vecs)
+    function unflatten_to_linearmixingmodelkernel(v::Vector{T})
+        length(v) == n || error("incorrect number of parameters")
+        offset = Ref(0)
+        kernels = map(kernel_backs, length_kernel_vecs) do back, length_vec
+            oldoffset = offset[]
+            newoffset = offset[] = oldoffset + length_vec
+            return back(v[(oldoffset + 1):newoffset])
+        end
+        H = H_back(v[(nkernel + 1):end])
+        return LinearMixingModelKernel(kernels, H)
+    end
+    return flat_vecs, unflatten_to_linearmixingmodelkernel
+end
+
 function (κ::LinearMixingModelKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
     (px > size(κ.H, 2) || py > size(κ.H, 2) || px < 1 || py < 1) &&
         error("`px` and `py` must be within the range of the number of outputs")
diff --git a/src/test_utils.jl b/src/test_utils.jl
index 6ebb1d068..b791fa084 100644
--- a/src/test_utils.jl
+++ b/src/test_utils.jl
@@ -6,6 +6,7 @@ const __RTOL = 1e-9
 using Distances
 using LinearAlgebra
 using KernelFunctions
+using ParameterHandling
 using Random
 using Test
 
@@ -88,6 +89,9 @@ function test_interface(
     tmp_diag = Vector{Float64}(undef, length(x0))
     @test kernelmatrix_diag!(tmp_diag, k, x0) ≈ kernelmatrix_diag(k, x0)
     @test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1)
+
+    # Check flatten/unflatten
+    test_flatten_interface(k)
 end
 
 function 
test_interface( diff --git a/src/transform/ardtransform.jl b/src/transform/ardtransform.jl index 4e71d0141..2f83ace09 100644 --- a/src/transform/ardtransform.jl +++ b/src/transform/ardtransform.jl @@ -23,7 +23,11 @@ Create an [`ARDTransform`](@ref) with vector `fill(s, dims)`. """ ARDTransform(s::Real, dims::Integer) = ARDTransform(fill(s, dims)) -@functor ARDTransform +function ParameterHandling.flatten(::Type{T}, t::ARDTransform) where {T<:Real} + vec, back = flatten(T, t.v) + unflatten_to_ardtransform(v::Vector{T}) = ARDTransform(back(v)) + return vec, unflatten_to_ardtransform +end function set!(t::ARDTransform{<:AbstractVector{T}}, ρ::AbstractVector{T}) where {T<:Real} @assert length(ρ) == dim(t) "Trying to set a vector of size $(length(ρ)) to ARDTransform of dimension $(dim(t))" diff --git a/src/transform/chaintransform.jl b/src/transform/chaintransform.jl index bd4627b19..32ac558cf 100644 --- a/src/transform/chaintransform.jl +++ b/src/transform/chaintransform.jl @@ -25,6 +25,25 @@ end @functor ChainTransform +function ParameterHandling.flatten(::Type{T}, t::ChainTransform) where {T<:Real} + vecs_and_backs = map(Base.Fix1(flatten, T), t.transforms) + vecs = map(first, vecs_and_backs) + length_vecs = map(length, vecs) + backs = map(last, vecs_and_backs) + flat_vecs = reduce(vcat, vecs) + function unflatten_to_chaintransform(v::Vector{T}) + length(v) == length(flat_vecs) || error("incorrect number of parameters") + offset = Ref(0) + transforms = map(backs, length_vecs) do back, length_vec + oldoffset = offset[] + newoffset = offset[] = oldoffset + length_vec + return back(v[(oldoffset + 1):newoffset]) + end + return ChainTransform(transforms) + end + return flat_vecs, unflatten_to_chaintransform +end + Base.length(t::ChainTransform) = length(t.transforms) # Constructor to create a chain transform with an array of parameters diff --git a/src/transform/lineartransform.jl b/src/transform/lineartransform.jl index b61ba6a94..620f99c02 100644 --- a/src/transform/lineartransform.jl +++ b/src/transform/lineartransform.jl @@ -18,7 +18,11 @@ struct LinearTransform{T<:AbstractMatrix{<:Real}} <: Transform A::T end -@functor LinearTransform +function ParameterHandling.flatten(::Type{T}, t::LinearTransform) where {T<:Real} + vec, back = flatten(T, t.A) + unflatten_to_lineartransform(v::Vector{T}) = LinearTransform(back(v)) + return vec, unflatten_to_lineartransform +end function set!(t::LinearTransform{<:AbstractMatrix{T}}, A::AbstractMatrix{T}) where {T<:Real} size(t.A) == size(A) || error( diff --git a/src/transform/periodic_transform.jl b/src/transform/periodic_transform.jl index 3430a63a1..115625a61 100644 --- a/src/transform/periodic_transform.jl +++ b/src/transform/periodic_transform.jl @@ -15,26 +15,26 @@ julia> t(x) == [sinpi(2 * f * x), cospi(2 * f * x)] true ``` """ -struct PeriodicTransform{Tf<:AbstractVector{<:Real}} <: Transform - f::Tf +struct PeriodicTransform{T<:Real} <: Transform + f::T end -@functor PeriodicTransform - -PeriodicTransform(f::Real) = PeriodicTransform([f]) +function ParameterHandling.flatten(::Type{T}, t::PeriodicTransform) where {T<:Real} + f = t.f + unflatten_to_periodictransform(v::Vector{T}) = PeriodicTransform(oftype(f, only(v))) + return T[f], unflatten_to_periodictransform +end dim(t::PeriodicTransform) = 2 -(t::PeriodicTransform)(x::Real) = [sinpi(2 * first(t.f) * x), cospi(2 * first(t.f) * x)] +(t::PeriodicTransform)(x::Real) = [sinpi(2 * t.f * x), cospi(2 * t.f * x)] function _map(t::PeriodicTransform, x::AbstractVector{<:Real}) - return 
RowVecs(hcat(sinpi.((2 * first(t.f)) .* x), cospi.((2 * first(t.f)) .* x))) + return RowVecs(hcat(sinpi.((2 * t.f) .* x), cospi.((2 * t.f) .* x))) end -function Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform) - return isequal(first(t1.f), first(t2.f)) -end +Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform) = isequal(t1.f, t2.f) function Base.show(io::IO, t::PeriodicTransform) - return print(io, "Periodic Transform with frequency $(first(t.f))") + return print(io, "Periodic Transform with frequency ", t.f) end diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index 4cd1c5443..dc714cc0f 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -13,23 +13,33 @@ true ``` """ struct ScaleTransform{T<:Real} <: Transform - s::Vector{T} -end + s::T -function ScaleTransform(s::T=1.0) where {T<:Real} - return ScaleTransform{T}([s]) + function ScaleTransform(s::Real) + @check_args(ScaleTransform, s > zero(s), "s > 0") + return new{typeof(s)}(s) + end end -@functor ScaleTransform +ScaleTransform() = ScaleTransform(1.0) + -set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ] -(t::ScaleTransform)(x) = first(t.s) * x +function ParameterHandling.flatten(::Type{T}, t::ScaleTransform{S}) where {T<:Real,S<:Real} + s = t.s + function unflatten_to_scaletransform(v::Vector{T}) + length(v) == 1 || error("incorrect number of parameters") + ScaleTransform(S(first(v))) + end + return T[s], unflatten_to_scaletransform +end + +(t::ScaleTransform)(x) = t.s * x -_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x -_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X) -_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X) +_map(t::ScaleTransform, x::AbstractVector{<:Real}) = t.s .* x +_map(t::ScaleTransform, x::ColVecs) = ColVecs(t.s .* x.X) +_map(t::ScaleTransform, x::RowVecs) = RowVecs(t.s .* x.X) -Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(first(t.s), first(t2.s)) +Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(t.s, t2.s) -Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", first(t.s), ")") +Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", t.s, ")") diff --git a/src/transform/transform.jl b/src/transform/transform.jl index 40ce8c058..ac1da565e 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -15,6 +15,8 @@ Transformation that returns exactly the input. """ struct IdentityTransform <: Transform end +@noparams IdentityTransform + (t::IdentityTransform)(x) = x _map(::IdentityTransform, x::AbstractVector) = x diff --git a/src/utils.jl b/src/utils.jl index 7eea4358c..9e080cdc2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -221,3 +221,20 @@ end function validate_inplace_dims(K::AbstractVecOrMat, x::AbstractVector) return validate_inplace_dims(K, x, x) end + +# TODO: move to ParameterHandling? +""" + @noparams T + +Define `ParameterHandling.flatten` for a type `T` without parameters. 
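+
+For example, a parameter-free kernel opts in like this (a sketch; `MyKernel` is
+a hypothetical type, not part of the package):
+
+```julia
+struct MyKernel <: Kernel end
+@noparams MyKernel
+
+# `flatten` then yields no tunable parameters and an `unflatten`
+# closure that returns the kernel unchanged.
+v, unflatten = ParameterHandling.flatten(Float64, MyKernel())
+```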
+""" +macro noparams(T) + return quote + Base.@__doc__ function ParameterHandling.flatten( + ::Type{S}, x::$(esc(T)) + ) where {S<:Real} + unflatten(v::Vector{S}) = x + return v, unflatten + end + end +end From 9f3269e5cdb0894e4c2f5ea9c2d52bf877713201 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 3 Nov 2021 01:26:52 +0100 Subject: [PATCH 02/13] Some fixes --- src/basekernels/constant.jl | 2 +- src/basekernels/exponential.jl | 8 ++++---- src/basekernels/polynomial.jl | 4 +++- src/basekernels/rational.jl | 21 ++++++++------------- src/kernels/scaledkernel.jl | 8 +++++--- src/mokernels/independent.jl | 4 +++- src/test_utils.jl | 6 ++++-- src/transform/scaletransform.jl | 4 +--- test/approximations/nystrom.jl | 4 ++-- test/basekernels/fbm.jl | 2 +- test/basekernels/matern.jl | 2 +- test/basekernels/piecewisepolynomial.jl | 2 +- test/basekernels/wiener.jl | 4 ++-- test/kernels/gibbskernel.jl | 2 +- test/kernels/neuralkernelnetwork.jl | 4 ++-- test/kernels/transformedkernel.jl | 4 ++-- test/matrix/kernelmatrix.jl | 4 ++-- test/mokernels/independent.jl | 2 +- test/mokernels/intrinsiccoregion.jl | 6 +++--- test/utils.jl | 8 ++++---- 20 files changed, 51 insertions(+), 50 deletions(-) diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl index 21879a9be..5df35a058 100644 --- a/src/basekernels/constant.jl +++ b/src/basekernels/constant.jl @@ -80,7 +80,7 @@ ConstantKernel(; c::Real=1.0) = ConstantKernel(c) function ParameterHandling.flatten(::Type{T}, k::ConstantKernel{S}) where {T<:Real,S} function unflatten_to_constantkernel(v::Vector{T}) length(v) == 1 || error("incorrect number of parameters") - ConstantKernel(; c=S(exp(first(v)))) + return ConstantKernel(; c=S(exp(first(v)))) end return T[log(k.c)], unflatten_to_constantkernel end diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl index f535df259..4e7e375fb 100644 --- a/src/basekernels/exponential.jl +++ b/src/basekernels/exponential.jl @@ -139,7 +139,9 @@ function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclid return GammaExponentialKernel(γ, metric) end -function ParameterHandling.flatten(::Type{T}, k::GammaExponentialKernel{S}) where {T<:Real,S<:Real} +function ParameterHandling.flatten( + ::Type{T}, k::GammaExponentialKernel{S} +) where {T<:Real,S<:Real} metric = k.metric function unflatten_to_gammaexponentialkernel(v::Vector{T}) length(v) == 1 || error("incorrect number of parameters") @@ -156,7 +158,5 @@ metric(k::GammaExponentialKernel) = k.metric iskroncompatible(::GammaExponentialKernel) = true function Base.show(io::IO, κ::GammaExponentialKernel) - return print( - io, "Gamma Exponential Kernel (γ = ", κ.γ, ", metric = ", κ.metric, ")" - ) + return print(io, "Gamma Exponential Kernel (γ = ", κ.γ, ", metric = ", κ.metric, ")") end diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index 8a42bb962..aa0965573 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -68,7 +68,9 @@ function PolynomialKernel(; degree::Int=2, c::Real=0.0) return PolynomialKernel{typeof(c)}(degree, [c]) end -function ParameterHandling.flatten(::Type{T}, k::PolynomialKernel{S}) where {T<:Real,S<:Real} +function ParameterHandling.flatten( + ::Type{T}, k::PolynomialKernel{S} +) where {T<:Real,S<:Real} degree = k.degree function unflatten_to_polynomialkernel(v::Vector{T}) length(v) == 1 || error("incorrect number of parameters") diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index 69fd32fdd..55318ada9 100644 --- 
a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -77,7 +77,9 @@ struct RationalQuadraticKernel{T<:Real,M} <: SimpleKernel end end -function ParameterHandling.flatten(::Type{T}, k::RationalQuadraticKernel{S}) where {T<:Real,S} +function ParameterHandling.flatten( + ::Type{T}, k::RationalQuadraticKernel{S} +) where {T<:Real,S} metric = k.metric function unflatten_to_rationalquadratickernel(v::Vector{T}) length(v) == 1 || error("incorrect number of parameters") @@ -99,9 +101,7 @@ metric(k::RationalQuadraticKernel) = k.metric metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean() function Base.show(io::IO, κ::RationalQuadraticKernel) - return print( - io, "Rational Quadratic Kernel (α = ", κ.α, ", metric = ", κ.metric, ")" - ) + return print(io, "Rational Quadratic Kernel (α = ", κ.α, ", metric = ", κ.metric, ")") end """ @@ -138,7 +138,9 @@ end @functor GammaRationalKernel -function ParameterHandling.flatten(::Type{T}, k::GammaRationalKernel{Tα,Tγ}) where {T<:Real,Tα,Tγ} +function ParameterHandling.flatten( + ::Type{T}, k::GammaRationalKernel{Tα,Tγ} +) where {T<:Real,Tα,Tγ} vec = T[log(k.α), logit(k.γ - 1)] metric = k.metric function unflatten_to_gammarationalkernel(v::Vector{T}) @@ -160,13 +162,6 @@ metric(k::GammaRationalKernel) = k.metric function Base.show(io::IO, κ::GammaRationalKernel) return print( - io, - "Gamma Rational Kernel (α = ", - κ.α, - ", γ = ", - κ.γ, - ", metric = ", - κ.metric, - ")", + io, "Gamma Rational Kernel (α = ", κ.α, ", γ = ", κ.γ, ", metric = ", κ.metric, ")" ) end diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 4361064f7..becac13ac 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -27,10 +27,12 @@ ScaledKernel(kernel::Kernel) = ScaledKernel(kernel, 1.0) # parameter constraints @functor ScaledKernel (kernel,) -function ParameterHandling.flatten(::Type{T}, k::ScaledKernel{<:Kernel,S}) where {T<:Real,S<:Real} +function ParameterHandling.flatten( + ::Type{T}, k::ScaledKernel{<:Kernel,S} +) where {T<:Real,S<:Real} kernel_vec, kernel_back = flatten(T, k.kernel) function unflatten_to_scaledkernel(v::Vector{T}) - kernel = kernel_back(v[1:end-1]) + kernel = kernel_back(v[1:(end - 1)]) return ScaledKernel(kernel, S(exp(last(v)))) end return vcat(kernel_vec, T(log(k.σ²))), unflatten_to_scaledkernel @@ -92,5 +94,5 @@ function printshifted(io::IO, κ::ScaledKernel, shift::Int) for _ in 1:(shift + 1) print(io, "\t") end - print(io, "- σ² = ", κ.σ²) + return print(io, "- σ² = ", κ.σ²) end diff --git a/src/mokernels/independent.jl b/src/mokernels/independent.jl index d722146c0..10e8197fd 100644 --- a/src/mokernels/independent.jl +++ b/src/mokernels/independent.jl @@ -27,7 +27,9 @@ end function ParameterHandling.flatten(::Type{T}, k::IndependentMOKernel) where {T<:Real} vec, unflatten_to_kernel = flatten(T, k.kernel) - unflatten_to_independentmokernel(v::Vector{T}) = IndependentMOKernel(unflatten_to_kernel(v)) + function unflatten_to_independentmokernel(v::Vector{T}) + return IndependentMOKernel(unflatten_to_kernel(v)) + end return vec, unflatten_to_independentmokernel end diff --git a/src/test_utils.jl b/src/test_utils.jl index b791fa084..344994588 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -28,7 +28,7 @@ be of different lengths. `test_interface` offers certain types of test data generation to make running these tests require less code for common input types. For example, `Vector{<:Real}`, `ColVecs{<:Real}`, and `RowVecs{<:Real}` are supported. 
For other input vector types, please provide the data -manually. +manually. """ function test_interface( k::Kernel, @@ -91,7 +91,9 @@ function test_interface( @test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1) # Check flatten/unflatten - test_flatten_interface(k) + ParameterHandling.TestUtils.test_flatten_interface(k) + + return nothing end function test_interface( diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index dc714cc0f..c39b25823 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -23,13 +23,11 @@ end ScaleTransform() = ScaleTransform(1.0) - - function ParameterHandling.flatten(::Type{T}, t::ScaleTransform{S}) where {T<:Real,S<:Real} s = t.s function unflatten_to_scaletransform(v::Vector{T}) length(v) == 1 || error("incorrect number of parameters") - ScaleTransform(S(first(v))) + return ScaleTransform(S(first(v))) end return T[s], unflatten_to_scaletransform end diff --git a/test/approximations/nystrom.jl b/test/approximations/nystrom.jl index 5e9c6773d..11be1055f 100644 --- a/test/approximations/nystrom.jl +++ b/test/approximations/nystrom.jl @@ -4,8 +4,8 @@ k = SqExponentialKernel() for obsdim in [1, 2] @test kernelmatrix(k, X; obsdim=obsdim) ≈ - kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim)) + kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim)) @test kernelmatrix(k, X; obsdim=obsdim) ≈ - kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim)) + kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim)) end end diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl index 6ab3704d9..990026429 100644 --- a/test/basekernels/fbm.jl +++ b/test/basekernels/fbm.jl @@ -5,7 +5,7 @@ v1 = rand(rng, 3) v2 = rand(rng, 3) @test k(v1, v2) ≈ - ( + ( sqeuclidean(v1, zero(v1))^h + sqeuclidean(v2, zero(v2))^h - sqeuclidean(v1 - v2, zero(v1 - v2))^h ) / 2 atol = 1e-5 diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index dedbd3847..5a83de55b 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -45,7 +45,7 @@ k = Matern52Kernel() @test kappa(k, x) ≈ (1 + sqrt(5) * x + 5 / 3 * x^2)exp(-sqrt(5) * x) @test k(v1, v2) ≈ - ( + ( 1 + sqrt(5) * norm(v1 - v2) + 5 / 3 * norm(v1 - v2)^2 )exp(-sqrt(5) * norm(v1 - v2)) @test kappa(Matern52Kernel(), x) == kappa(k, x) diff --git a/test/basekernels/piecewisepolynomial.jl b/test/basekernels/piecewisepolynomial.jl index 26ec1a545..1c624271d 100644 --- a/test/basekernels/piecewisepolynomial.jl +++ b/test/basekernels/piecewisepolynomial.jl @@ -20,7 +20,7 @@ @test PiecewisePolynomialKernel(; dim=D) isa PiecewisePolynomialKernel{0} @test repr(k) == - "Piecewise Polynomial Kernel (degree = $(degree), ⌊dim/2⌋ = $(div(D, 2)), metric = Euclidean(0.0))" + "Piecewise Polynomial Kernel (degree = $(degree), ⌊dim/2⌋ = $(div(D, 2)), metric = Euclidean(0.0))" k3 = PiecewisePolynomialKernel(; degree=degree, dim=D, metric=WeightedEuclidean(ones(D)) diff --git a/test/basekernels/wiener.jl b/test/basekernels/wiener.jl index e8ba6dcf2..9dd60ba43 100644 --- a/test/basekernels/wiener.jl +++ b/test/basekernels/wiener.jl @@ -27,9 +27,9 @@ @test k0(v1, v2) ≈ minXY @test k1(v1, v2) ≈ 1 / 3 * minXY^3 + 1 / 2 * minXY^2 * euclidean(v1, v2) @test k2(v1, v2) ≈ - 1 / 20 * minXY^5 + 1 / 12 * minXY^3 * euclidean(v1, v2) * (X + Y - 1 / 2 * minXY) + 1 / 20 * minXY^5 + 1 / 12 * minXY^3 * euclidean(v1, v2) * (X + Y - 1 / 2 * minXY) @test k3(v1, v2) ≈ - 1 / 252 * minXY^7 + + 1 / 252 * minXY^7 + 1 / 720 * minXY^4 * euclidean(v1, v2) * diff --git 
a/test/kernels/gibbskernel.jl b/test/kernels/gibbskernel.jl index 3c1722dcd..f0109afe9 100644 --- a/test/kernels/gibbskernel.jl +++ b/test/kernels/gibbskernel.jl @@ -8,6 +8,6 @@ k_gibbs = GibbsKernel(ell) @test k_gibbs(x, y) ≈ - sqrt((2 * ell(x) * ell(y)) / (ell(x)^2 + ell(y)^2)) * + sqrt((2 * ell(x) * ell(y)) / (ell(x)^2 + ell(y)^2)) * exp(-(x - y)^2 / (ell(x)^2 + ell(y)^2)) end diff --git a/test/kernels/neuralkernelnetwork.jl b/test/kernels/neuralkernelnetwork.jl index 25a4da990..31e69621d 100644 --- a/test/kernels/neuralkernelnetwork.jl +++ b/test/kernels/neuralkernelnetwork.jl @@ -43,12 +43,12 @@ using KernelFunctions: NeuralKernelNetwork, LinearLayer, product, Primitive # Vector input. @test kernelmatrix_diag(nkn_add_kernel, x0) ≈ kernelmatrix_diag(sum_k, x0) @test kernelmatrix_diag(nkn_add_kernel, x0, x1) ≈ - kernelmatrix_diag(sum_k, x0, x1) + kernelmatrix_diag(sum_k, x0, x1) # ColVecs input. @test kernelmatrix_diag(nkn_add_kernel, X0) ≈ kernelmatrix_diag(sum_k, X0) @test kernelmatrix_diag(nkn_add_kernel, X0, X1) ≈ - kernelmatrix_diag(sum_k, X0, X1) + kernelmatrix_diag(sum_k, X0, X1) end @testset "product" begin nkn_prod_kernel = NeuralKernelNetwork(primitives, product) diff --git a/test/kernels/transformedkernel.jl b/test/kernels/transformedkernel.jl index 39176d96d..0295649bc 100644 --- a/test/kernels/transformedkernel.jl +++ b/test/kernels/transformedkernel.jl @@ -19,8 +19,8 @@ @test ktard(v1, v2) == (k ∘ ARDTransform(v))(v1, v2) @test ktard(v1, v2) == k(v .* v1, v .* v2) @test (k ∘ LinearTransform(P') ∘ ScaleTransform(s))(v1, v2) == - ((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2) == - (k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2) + ((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2) == + (k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2) @test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s)) diff --git a/test/matrix/kernelmatrix.jl b/test/matrix/kernelmatrix.jl index dc4b5ebc0..b5bf3fff8 100644 --- a/test/matrix/kernelmatrix.jl +++ b/test/matrix/kernelmatrix.jl @@ -132,11 +132,11 @@ KernelFunctions.kappa(::ToySimpleKernel, d) = exp(-d / 2) tmp_diag = Vector{Float64}(undef, length(x)) @test kernelmatrix_diag(k, x) ≈ - kernelmatrix_diag!(tmp_diag, k, X; obsdim=obsdim) + kernelmatrix_diag!(tmp_diag, k, X; obsdim=obsdim) @test kernelmatrix_diag(k, x) ≈ tmp_diag tmp_diag = Vector{Float64}(undef, length(x)) @test kernelmatrix_diag!(tmp_diag, k, X, X; obsdim=obsdim) ≈ - kernelmatrix_diag(k, x, x) + kernelmatrix_diag(k, x, x) @test tmp_diag ≈ kernelmatrix_diag(k, x, x) end end diff --git a/test/mokernels/independent.jl b/test/mokernels/independent.jl index 8a9ba733d..a1ce1c25a 100644 --- a/test/mokernels/independent.jl +++ b/test/mokernels/independent.jl @@ -27,6 +27,6 @@ @test eltype(typeof(kernelmatrix(k, x2))) <: Float32 @test string(k) == - "Independent Multi-Output Kernel\n" * + "Independent Multi-Output Kernel\n" * "\tSquared Exponential Kernel (metric = Euclidean(0.0))" end diff --git a/test/mokernels/intrinsiccoregion.jl b/test/mokernels/intrinsiccoregion.jl index 2d6ee9913..b08930441 100644 --- a/test/mokernels/intrinsiccoregion.jl +++ b/test/mokernels/intrinsiccoregion.jl @@ -27,9 +27,9 @@ @test icoregionkernel.B == B @test icoregionkernel.kernel == kernel @test icoregionkernel(XIF[1], XIF[1]) ≈ - B[XIF[1][2], XIF[1][2]] * kernel(XIF[1][1], XIF[1][1]) + B[XIF[1][2], XIF[1][2]] * kernel(XIF[1][1], XIF[1][1]) @test icoregionkernel(XIF[1], XIF[end]) ≈ - B[XIF[1][2], XIF[end][2]] * kernel(XIF[1][1], XIF[end][1]) + B[XIF[1][2], XIF[end][2]] * 
kernel(XIF[1][1], XIF[end][1]) # kernelmatrix KernelFunctions.TestUtils.test_interface(icoregionkernel, XIF, YIF, ZIF) @@ -43,5 +43,5 @@ test_ADs(icoregionkernel; dims=dims) @test string(icoregionkernel) == - string("Intrinsic Coregion Kernel: ", kernel, " with ", dims.out, " outputs") + string("Intrinsic Coregion Kernel: ", kernel, " with ", dims.out, " outputs") end diff --git a/test/utils.jl b/test/utils.jl index aacd27348..233803725 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -30,9 +30,9 @@ Y = randn(rng, D, N + 1) DY = ColVecs(Y) @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ - pairwise(SqEuclidean(), X; dims=2) + pairwise(SqEuclidean(), X; dims=2) @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ - pairwise(SqEuclidean(), X, Y; dims=2) + pairwise(SqEuclidean(), X, Y; dims=2) @test vcat(DX, DY) isa ColVecs @test vcat(DX, DY).X == hcat(X, Y) K = zeros(N, N) @@ -87,9 +87,9 @@ Y = randn(rng, D + 1, N) DY = RowVecs(Y) @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ - pairwise(SqEuclidean(), X; dims=1) + pairwise(SqEuclidean(), X; dims=1) @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ - pairwise(SqEuclidean(), X, Y; dims=1) + pairwise(SqEuclidean(), X, Y; dims=1) @test vcat(DX, DY) isa RowVecs @test vcat(DX, DY).X == vcat(X, Y) K = zeros(D, D) From d1fe6ced159a39a93bf885965be91fa1f0ec677b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 9 Nov 2021 21:00:49 +0100 Subject: [PATCH 03/13] Apply suggestions from code review --- src/basekernels/constant.jl | 3 +-- src/basekernels/exponential.jl | 5 ++--- src/basekernels/fbm.jl | 5 ++--- src/basekernels/matern.jl | 6 +----- src/basekernels/polynomial.jl | 10 ++++------ src/basekernels/rational.jl | 10 ++++------ src/transform/ardtransform.jl | 7 +++---- src/transform/scaletransform.jl | 8 ++------ 8 files changed, 19 insertions(+), 35 deletions(-) diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl index 5df35a058..5d22d8309 100644 --- a/src/basekernels/constant.jl +++ b/src/basekernels/constant.jl @@ -79,8 +79,7 @@ ConstantKernel(; c::Real=1.0) = ConstantKernel(c) function ParameterHandling.flatten(::Type{T}, k::ConstantKernel{S}) where {T<:Real,S} function unflatten_to_constantkernel(v::Vector{T}) - length(v) == 1 || error("incorrect number of parameters") - return ConstantKernel(; c=S(exp(first(v)))) + return ConstantKernel(; c=S(exp(only(v)))) end return T[log(k.c)], unflatten_to_constantkernel end diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl index 4e7e375fb..ab6232201 100644 --- a/src/basekernels/exponential.jl +++ b/src/basekernels/exponential.jl @@ -144,11 +144,10 @@ function ParameterHandling.flatten( ) where {T<:Real,S<:Real} metric = k.metric function unflatten_to_gammaexponentialkernel(v::Vector{T}) - length(v) == 1 || error("incorrect number of parameters") - γ = S(1 + logistic(first(v))) + γ = S(2 * logistic(only(v))) return GammaExponentialKernel(; γ=γ, metric=metric) end - return T[logit(k.γ - 1)], unflatten_to_gammaexponentialkernel + return T[logit(k.γ / 2)], unflatten_to_gammaexponentialkernel end kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^κ.γ) diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index 980747ff7..6db795a1a 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -25,11 +25,10 @@ FBMKernel(; h::Real=0.5) = FBMKernel(h) function ParameterHandling.flatten(::Type{T}, k::FBMKernel{S}) where {T<:Real,S<:Real} function unflatten_to_fbmkernel(v::Vector{T}) - length(v) == 1 || error("incorrect number of 
parameters") - h = S((1 + logistic(first(v))) / 2) + h = S(logistic(only(v))) return FBMKernel(h) end - return T[logit(2 * k.h - 1)], unflatten_to_fbmkernel + return T[logit(k.h)], unflatten_to_fbmkernel end function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index ffb76ffdd..d792b409f 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -33,11 +33,7 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, function ParameterHandling.flatten(::Type{T}, k::MaternKernel{S}) where {T<:Real,S<:Real} metric = k.metric - function unflatten_to_maternkernel(v::Vector{T}) - length(v) == 1 || error("incorrect number of parameters") - v = S(exp(first(v))) - return MaternKernel(ν, metric) - end + unflatten_to_maternkernel(v::Vector{T}) = MaternKernel(S(exp(first(v))), metric) return T[log(k.ν)], unflatten_to_maternkernel end diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index aa0965573..c299e4499 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -26,10 +26,9 @@ LinearKernel(; c::Real=0.0) = LinearKernel(c) function ParameterHandling.flatten(::Type{T}, k::LinearKernel{S}) where {T<:Real,S<:Real} function unflatten_to_linearkernel(v::Vector{T}) - length(v) == 1 || error("incorrect number of parameters") - return LinearKernel(S(first(v))) + return LinearKernel(S(exp(only(v)))) end - return T[k.c], unflatten_to_linearkernel + return T[log(k.c)], unflatten_to_linearkernel end kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + κ.c @@ -73,10 +72,9 @@ function ParameterHandling.flatten( ) where {T<:Real,S<:Real} degree = k.degree function unflatten_to_polynomialkernel(v::Vector{T}) - length(v) == 1 || error("incorrect number of parameters") - return PolynomialKernel(degree, S(first(v))) + return PolynomialKernel(degree, S(exp(only(v)))) end - return T[k.c], unflatten_to_polynomialkernel + return T[log(k.c)], unflatten_to_polynomialkernel end kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + κ.c)^κ.degree diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index 55318ada9..7c2cedb91 100644 --- a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -32,8 +32,7 @@ end function ParameterHandling.flatten(::Type{T}, k::RationalKernel{S}) where {T<:Real,S} metric = k.metric function unflatten_to_rationalkernel(v::Vector{T}) - length(v) == 1 || error("incorrect number of parameters") - return ConstantKernel(S(exp(first(v))), metric) + return ConstantKernel(S(exp(only(v))), metric) end return T[log(k.α)], unflatten_to_rationalkernel end @@ -82,8 +81,7 @@ function ParameterHandling.flatten( ) where {T<:Real,S} metric = k.metric function unflatten_to_rationalquadratickernel(v::Vector{T}) - length(v) == 1 || error("incorrect number of parameters") - return RationalQuadraticKernel(; α=S(exp(first(v))), metric=metric) + return RationalQuadraticKernel(; α=S(exp(only(v))), metric=metric) end return T[log(k.α)], unflatten_to_rationalquadratickernel end @@ -141,13 +139,13 @@ end function ParameterHandling.flatten( ::Type{T}, k::GammaRationalKernel{Tα,Tγ} ) where {T<:Real,Tα,Tγ} - vec = T[log(k.α), logit(k.γ - 1)] + vec = T[log(k.α), logit(k.γ / 2)] metric = k.metric function unflatten_to_gammarationalkernel(v::Vector{T}) length(v) == 2 || error("incorrect number of parameters") logα, logitγ = v α = Tα(exp(logα)) - γ = Tγ(1 + logistic(logitγ)) + γ = Tγ(2 * logistic(logitγ)) return GammaRationalKernel(; α=α, γ=γ, 
metric=metric) end return vec, unflatten_to_gammarationalkernel diff --git a/src/transform/ardtransform.jl b/src/transform/ardtransform.jl index 2f83ace09..4f7967d88 100644 --- a/src/transform/ardtransform.jl +++ b/src/transform/ardtransform.jl @@ -23,10 +23,9 @@ Create an [`ARDTransform`](@ref) with vector `fill(s, dims)`. """ ARDTransform(s::Real, dims::Integer) = ARDTransform(fill(s, dims)) -function ParameterHandling.flatten(::Type{T}, t::ARDTransform) where {T<:Real} - vec, back = flatten(T, t.v) - unflatten_to_ardtransform(v::Vector{T}) = ARDTransform(back(v)) - return vec, unflatten_to_ardtransform +function ParameterHandling.flatten(::Type{T}, t::ARDTransform{S}) where {T<:Real,S} + unflatten_to_ardtransform(v::Vector{T}) = ARDTransform(convert(S, map(exp, v))) + return convert(Vector{T}, map(log, t.v)), unflatten_to_ardtransform end function set!(t::ARDTransform{<:AbstractVector{T}}, ρ::AbstractVector{T}) where {T<:Real} diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index c39b25823..a073be28e 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -24,12 +24,8 @@ end ScaleTransform() = ScaleTransform(1.0) function ParameterHandling.flatten(::Type{T}, t::ScaleTransform{S}) where {T<:Real,S<:Real} - s = t.s - function unflatten_to_scaletransform(v::Vector{T}) - length(v) == 1 || error("incorrect number of parameters") - return ScaleTransform(S(first(v))) - end - return T[s], unflatten_to_scaletransform + unflatten_to_scaletransform(v::Vector{T}) = ScaleTransform(S(exp(only(v)))) + return T[log(t.s)], unflatten_to_scaletransform end (t::ScaleTransform)(x) = t.s * x From ec500c38e8b6caf27b594e9a0cbfa8404a12e0f9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 9 Nov 2021 22:53:22 +0100 Subject: [PATCH 04/13] Update src/utils.jl --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 9e080cdc2..8fe52bfc7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -234,7 +234,7 @@ macro noparams(T) ::Type{S}, x::$(esc(T)) ) where {S<:Real} unflatten(v::Vector{S}) = x - return v, unflatten + return x, unflatten end end end From 1170c145c60b420ba7abfe501b761e3a97f1c40a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 9 Nov 2021 22:54:56 +0100 Subject: [PATCH 05/13] Update src/utils.jl --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 8fe52bfc7..b7fd48a09 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -233,8 +233,8 @@ macro noparams(T) Base.@__doc__ function ParameterHandling.flatten( ::Type{S}, x::$(esc(T)) ) where {S<:Real} - unflatten(v::Vector{S}) = x - return x, unflatten + unflatten(::Vector{S}) = x + return S[], unflatten end end end From 9aa600f5e7979efe12193d25eea7c95e2f20cfe7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 9 Nov 2021 23:16:22 +0100 Subject: [PATCH 06/13] Update src/transform/scaletransform.jl --- src/transform/scaletransform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index a073be28e..4be7f4ba3 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -16,7 +16,7 @@ struct ScaleTransform{T<:Real} <: Transform s::T function ScaleTransform(s::Real) - @check_args(ScaleTransform, s > zero(s), "s > 0") + @check_args(ScaleTransform, s, s > zero(s), "s > 0") return new{typeof(s)}(s) end end From 274f1e44e3f20efd0ebd2a79bcafbc1bc7f39c17 Mon Sep 17 
00:00:00 2001 From: David Widmann Date: Thu, 11 Nov 2021 01:39:29 +0100 Subject: [PATCH 07/13] Fix rational kernels --- src/KernelFunctions.jl | 2 +- src/basekernels/rational.jl | 4 +--- src/kernels/parameterkernel.jl | 10 ++++++++++ test/basekernels/rational.jl | 6 +++--- test/test_utils.jl | 2 +- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index a9f9afa8e..cdf9f25a5 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -58,7 +58,7 @@ using LinearAlgebra using Requires using SpecialFunctions: loggamma, besselk, polygamma using IrrationalConstants: logtwo, twoπ, invsqrt2 -using LogExpFunctions: softplus +using LogExpFunctions: logit, logistic, softplus using StatsBase using TensorCore using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index 7c2cedb91..2ae231bbd 100644 --- a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -32,7 +32,7 @@ end function ParameterHandling.flatten(::Type{T}, k::RationalKernel{S}) where {T<:Real,S} metric = k.metric function unflatten_to_rationalkernel(v::Vector{T}) - return ConstantKernel(S(exp(only(v))), metric) + return RationalKernel(S(exp(only(v))), metric) end return T[log(k.α)], unflatten_to_rationalkernel end @@ -134,8 +134,6 @@ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel end end -@functor GammaRationalKernel - function ParameterHandling.flatten( ::Type{T}, k::GammaRationalKernel{Tα,Tγ} ) where {T<:Real,Tα,Tγ} diff --git a/src/kernels/parameterkernel.jl b/src/kernels/parameterkernel.jl index 0f925abe9..5d14d16c8 100644 --- a/src/kernels/parameterkernel.jl +++ b/src/kernels/parameterkernel.jl @@ -92,6 +92,16 @@ struct ParameterKernel{P,K} <: Kernel kernel::K end +# convenience function +""" + ParameterKernel(kernel::Kernel) + +Construct a `ParameterKernel` from an existing `kernel`. + +The constructor is a short-hand for `ParameterKernel(ParameterHandling.flatten(kernel)...)`. +""" +ParameterKernel(kernel::Kernel) = ParameterKernel(flatten(kernel)...) + Functors.@functor ParameterKernel (params,) function ParameterHandling.flatten(::Type{T}, kernel::ParameterKernel) where {T<:Real} diff --git a/test/basekernels/rational.jl b/test/basekernels/rational.jl index f1a4faa86..e67101b5f 100644 --- a/test/basekernels/rational.jl +++ b/test/basekernels/rational.jl @@ -28,7 +28,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(x -> RationalKernel(; alpha=exp(x[1])), [α]) - test_params(k, ([α],)) + test_params(k, ([log(α)],)) end @testset "RationalQuadraticKernel" begin @@ -55,7 +55,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(x -> RationalQuadraticKernel(; alpha=exp(x[1])), [α]) - test_params(k, ([α],)) + test_params(k, ([log(α)],)) end @testset "GammaRationalKernel" begin @@ -129,6 +129,6 @@ TestUtils.test_interface(k, Float64) a = 1.0 + rand() test_ADs(x -> GammaRationalKernel(; α=x[1], γ=x[2]), [a, 1 + 0.5 * rand()]) - test_params(GammaRationalKernel(; α=a, γ=x), ([a], [x])) + test_params(GammaRationalKernel(; α=a, γ=x), ([log(a), logit(x / 2)],)) end end diff --git a/test/test_utils.jl b/test/test_utils.jl index 22fe9fb08..bd3c40b50 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -22,7 +22,7 @@ function params(m...) 
end function test_params(kernel, reference) - params_kernel = params(kernel) + params_kernel = params(ParameterKernel(kernel)) params_reference = params(reference) @test length(params_kernel) == length(params_reference) From 408531f451a0329ea35bf4cecbe679cd07986a75 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 11 Nov 2021 01:40:49 +0100 Subject: [PATCH 08/13] Fix some other tests --- test/basekernels/constant.jl | 2 +- test/basekernels/exponential.jl | 2 +- test/basekernels/fbm.jl | 2 +- test/basekernels/matern.jl | 2 +- test/basekernels/periodic.jl | 2 +- test/basekernels/polynomial.jl | 4 ++-- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/basekernels/constant.jl b/test/basekernels/constant.jl index e18df9419..ea19f9ecd 100644 --- a/test/basekernels/constant.jl +++ b/test/basekernels/constant.jl @@ -32,7 +32,7 @@ @test metric(ConstantKernel()) == KernelFunctions.Delta() @test metric(ConstantKernel(; c=2.0)) == KernelFunctions.Delta() @test repr(k) == "Constant Kernel (c = $(c))" - test_params(k, ([c],)) + test_params(k, ([log(c)],)) # Standardised tests. TestUtils.test_interface(k, Float64) diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index a002bf29d..e6230954a 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -57,7 +57,7 @@ @test k2(v1, v2) ≈ k(v1, v2) test_ADs(γ -> GammaExponentialKernel(; gamma=first(γ)), [1 + 0.5 * rand()]) - test_params(k, ([γ],)) + test_params(k, ([logit(γ / 2)],)) TestUtils.test_interface(GammaExponentialKernel(; γ=1.36)) #Coherence : diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl index 990026429..7cc1ce623 100644 --- a/test/basekernels/fbm.jl +++ b/test/basekernels/fbm.jl @@ -22,5 +22,5 @@ Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1] ) - test_params(k, ([h],)) + test_params(k, ([logit(h)],)) end diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 5a83de55b..daf447c7f 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -23,7 +23,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) - test_params(k, ([ν],)) + test_params(k, ([log(ν)],)) end @testset "Matern32Kernel" begin k = Matern32Kernel() diff --git a/test/basekernels/periodic.jl b/test/basekernels/periodic.jl index fb149dff5..a97a6ec86 100644 --- a/test/basekernels/periodic.jl +++ b/test/basekernels/periodic.jl @@ -17,5 +17,5 @@ # test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff]) @test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff" - test_params(k, (r,)) + test_params(k, (map(log, r),)) end diff --git a/test/basekernels/polynomial.jl b/test/basekernels/polynomial.jl index 00367b602..131efcd6f 100644 --- a/test/basekernels/polynomial.jl +++ b/test/basekernels/polynomial.jl @@ -19,7 +19,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(x -> LinearKernel(; c=x[1]), [c]) - test_params(LinearKernel(; c=c), ([c],)) + test_params(LinearKernel(; c=c), ([log(c)],)) end @testset "PolynomialKernel" begin k = PolynomialKernel() @@ -41,6 +41,6 @@ # Standardised tests. 
TestUtils.test_interface(k, Float64) test_ADs(x -> PolynomialKernel(; c=x[1]), [c]) - test_params(PolynomialKernel(; c=c), ([c],)) + test_params(PolynomialKernel(; c=c), ([log(c)],)) end end From 9e68e1330fa39c9e4c2dcd0b70e84ec1e3c4f52a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 11 Nov 2021 02:00:36 +0100 Subject: [PATCH 09/13] More fixes --- test/basekernels/exponential.jl | 2 +- test/basekernels/gabor.jl | 4 ++-- test/basekernels/matern.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index e6230954a..2012193d9 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -46,7 +46,7 @@ k = GammaExponentialKernel(; γ=γ) @test k(v1, v2) ≈ exp(-norm(v1 - v2)^γ) @test kappa(GammaExponentialKernel(), x) == kappa(k, x) - @test GammaExponentialKernel(; gamma=γ).γ == [γ] + @test GammaExponentialKernel(; gamma=γ).γ == γ @test metric(GammaExponentialKernel()) == Euclidean() @test metric(GammaExponentialKernel(; γ=2.0)) == Euclidean() @test repr(k) == "Gamma Exponential Kernel (γ = $(γ), metric = Euclidean(0.0))" diff --git a/test/basekernels/gabor.jl b/test/basekernels/gabor.jl index aa3047387..90c77cebe 100644 --- a/test/basekernels/gabor.jl +++ b/test/basekernels/gabor.jl @@ -13,8 +13,8 @@ TransformedKernel{<:CosineKernel,<:ScaleTransform}, }, } - @test k.kernels[1].transform.s[1] == inv(ell) - @test k.kernels[2].transform.s[1] == inv(p) + @test k.kernels[1].transform.s == inv(ell) + @test k.kernels[2].transform.s == inv(p) k_manual = exp(-sqeuclidean(v1, v2) / (2 * ell^2)) * cospi(euclidean(v1, v2) / p) @test k_manual ≈ k(v1, v2) atol = 1e-5 diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index daf447c7f..2db7cbe99 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -7,7 +7,7 @@ ν = 2.0 k = MaternKernel(; ν=ν) matern(x, ν) = 2^(1 - ν) / gamma(ν) * (sqrt(2ν) * x)^ν * besselk(ν, sqrt(2ν) * x) - @test MaternKernel(; nu=ν).ν == [ν] + @test MaternKernel(; nu=ν).ν == ν @test kappa(k, x) ≈ matern(x, ν) @test kappa(k, 0.0) == 1.0 @test kappa(MaternKernel(; ν=ν), x) == kappa(k, x) From 68d9f237c1b52a448c0b55ebe1f714ef28d9d770 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 11 Nov 2021 23:09:06 +0100 Subject: [PATCH 10/13] Update src/transform/ardtransform.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Théo Galy-Fajou --- src/transform/ardtransform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/ardtransform.jl b/src/transform/ardtransform.jl index 4f7967d88..72bc8f9f4 100644 --- a/src/transform/ardtransform.jl +++ b/src/transform/ardtransform.jl @@ -25,7 +25,7 @@ ARDTransform(s::Real, dims::Integer) = ARDTransform(fill(s, dims)) function ParameterHandling.flatten(::Type{T}, t::ARDTransform{S}) where {T<:Real,S} unflatten_to_ardtransform(v::Vector{T}) = ARDTransform(convert(S, map(exp, v))) - return convert(Vector{T}, map(log, t.v)), unflatten_to_ardtransform + return convert(Vector, map(T ∘ log, t.v)), unflatten_to_ardtransform end function set!(t::ARDTransform{<:AbstractVector{T}}, ρ::AbstractVector{T}) where {T<:Real} From bb8b6cd0d0a3f7797bc0e9405c4e70e02a057b8d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 11 Nov 2021 23:37:12 +0100 Subject: [PATCH 11/13] Fix Transform tests --- test/transform/scaletransform.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/transform/scaletransform.jl 
b/test/transform/scaletransform.jl index c0445b8a4..8b85d977f 100644 --- a/test/transform/scaletransform.jl +++ b/test/transform/scaletransform.jl @@ -14,10 +14,10 @@ @test all([t(x[n]) ≈ x′[n] for n in eachindex(x)]) end - s2 = 2.0 - KernelFunctions.set!(t, s2) - @test t.s == [s2] @test isequal(ScaleTransform(s), ScaleTransform(s)) - @test repr(t) == "Scale Transform (s = $(s2))" + + s2 = 2.0 + @test repr(ScaleTransform(s2)) == "Scale Transform (s = $(s2))" + test_ADs(x -> SEKernel() ∘ ScaleTransform(exp(x[1])), randn(rng, 1)) end From 01998cd2755884d5056820d52a7cb7987e4c5fc8 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 12 Nov 2021 00:18:54 +0100 Subject: [PATCH 12/13] Fix PolynomialKernel --- src/basekernels/polynomial.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index c299e4499..b04e53dc7 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -63,9 +63,7 @@ struct PolynomialKernel{T<:Real} <: SimpleKernel end end -function PolynomialKernel(; degree::Int=2, c::Real=0.0) - return PolynomialKernel{typeof(c)}(degree, [c]) -end +PolynomialKernel(; degree::Int=2, c::Real=0.0) = PolynomialKernel(degree, c) function ParameterHandling.flatten( ::Type{T}, k::PolynomialKernel{S} From 2ccf8cfc5f934e434f8e78a23a4b1fe3ddfbda39 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 12 Nov 2021 00:20:12 +0100 Subject: [PATCH 13/13] Update tests --- test/basekernels/constant.jl | 4 +++- test/basekernels/cosine.jl | 1 + test/basekernels/exponential.jl | 2 ++ test/basekernels/exponentiated.jl | 1 + test/basekernels/matern.jl | 2 ++ test/basekernels/nn.jl | 1 + test/basekernels/piecewisepolynomial.jl | 2 +- 7 files changed, 11 insertions(+), 2 deletions(-) diff --git a/test/basekernels/constant.jl b/test/basekernels/constant.jl index ea19f9ecd..fb9f9d0d7 100644 --- a/test/basekernels/constant.jl +++ b/test/basekernels/constant.jl @@ -8,6 +8,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) + test_params(k, (Float64[],)) test_ADs(ZeroKernel) end @testset "WhiteKernel" begin @@ -21,6 +22,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) + test_params(k, (Float64[],)) test_ADs(WhiteKernel) end @testset "ConstantKernel" begin @@ -32,10 +34,10 @@ @test metric(ConstantKernel()) == KernelFunctions.Delta() @test metric(ConstantKernel(; c=2.0)) == KernelFunctions.Delta() @test repr(k) == "Constant Kernel (c = $(c))" - test_params(k, ([log(c)],)) # Standardised tests. TestUtils.test_interface(k, Float64) + test_params(k, ([log(c)],)) test_ADs(c -> ConstantKernel(; c=first(c)), [c]) end end diff --git a/test/basekernels/cosine.jl b/test/basekernels/cosine.jl index ed24e4923..6d3fea5c3 100644 --- a/test/basekernels/cosine.jl +++ b/test/basekernels/cosine.jl @@ -19,5 +19,6 @@ # Standardised tests. TestUtils.test_interface(k, Vector{Float64}) + test_params(k, (Float64[],)) test_ADs(CosineKernel) end diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index 2012193d9..fcb3339b2 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -21,6 +21,7 @@ # Standardised tests. TestUtils.test_interface(k) + test_params(k, (Float64[],)) test_ADs(SEKernel) end @testset "ExponentialKernel" begin @@ -39,6 +40,7 @@ # Standardised tests. 
TestUtils.test_interface(k) + test_params(k, (Float64[],)) test_ADs(ExponentialKernel) end @testset "GammaExponentialKernel" begin diff --git a/test/basekernels/exponentiated.jl b/test/basekernels/exponentiated.jl index 1d209dc49..14a10ab24 100644 --- a/test/basekernels/exponentiated.jl +++ b/test/basekernels/exponentiated.jl @@ -13,5 +13,6 @@ # Standardised tests. This kernel appears to be fairly numerically unstable. TestUtils.test_interface(k; atol=1e-3) + test_params(k, (Float64[],)) test_ADs(ExponentiatedKernel) end diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 2db7cbe99..1d24f2256 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -39,6 +39,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) + test_params(k, (Float64[],)) test_ADs(Matern32Kernel) end @testset "Matern52Kernel" begin @@ -58,6 +59,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) + test_params(k, (Float64[],)) test_ADs(Matern52Kernel) end @testset "Coherence Materns" begin diff --git a/test/basekernels/nn.jl b/test/basekernels/nn.jl index c9dabeb69..59c0f1717 100644 --- a/test/basekernels/nn.jl +++ b/test/basekernels/nn.jl @@ -7,5 +7,6 @@ # Standardised tests. TestUtils.test_interface(k, Float64) + test_params(k, (Float64[],)) test_ADs(NeuralNetworkKernel) end diff --git a/test/basekernels/piecewisepolynomial.jl b/test/basekernels/piecewisepolynomial.jl index 1c624271d..1a424ef06 100644 --- a/test/basekernels/piecewisepolynomial.jl +++ b/test/basekernels/piecewisepolynomial.jl @@ -33,5 +33,5 @@ TestUtils.test_interface(k, RowVecs{Float64}; dim_in=2) test_ADs(() -> PiecewisePolynomialKernel{degree}(; dim=D)) - test_params(k, ()) + test_params(k, (Float64[],)) end
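
Taken together, these patches settle on a single flattening convention: positive parameters (c, α, ν, transform scales) are stored as their logarithms, and bounded parameters (γ ∈ (0, 2), h ∈ (0, 1)) through a scaled logit. Below is a minimal round-trip sketch, assuming the series is applied as-is; `logit` comes from LogExpFunctions, as imported in PATCH 07.

using KernelFunctions, ParameterHandling
using LogExpFunctions: logit

k = GammaExponentialKernel(; γ=1.5)          # γ is constrained to (0, 2)
v, unflatten = ParameterHandling.flatten(k)  # v is a plain Vector{Float64}
@assert v ≈ [logit(1.5 / 2)]                 # γ is stored unconstrained, as logit(γ / 2)

# Any real-valued update still maps back to a feasible kernel (via 2 * logistic):
k2 = unflatten(v .+ 0.1)
@assert 0 < k2.γ < 2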
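
For parameter-free kernels, the `@noparams` macro in `src/utils.jl` (as finalized in PATCHes 04 and 05) generates a trivial `flatten` method: an empty parameter vector plus an unflatten closure that ignores its argument. Modulo macro hygiene and `Base.@__doc__`, `@noparams ZeroKernel` would expand to roughly the following, consistent with the `test_params(k, (Float64[],))` expectations added in PATCH 13.

function ParameterHandling.flatten(::Type{S}, x::ZeroKernel) where {S<:Real}
    # No trainable parameters: ignore the (empty) vector and return the instance.
    unflatten(::Vector{S}) = x
    return S[], unflatten
end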
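
Finally, a usage sketch for the `ParameterKernel(kernel::Kernel)` convenience constructor from PATCH 07. The field names below follow the `Functors.@functor ParameterKernel (params,)` declaration; the comments describe intent and the concrete parameter value is an assumption, not part of the patch.

using KernelFunctions, ParameterHandling

base = SqExponentialKernel() ∘ ScaleTransform(0.5)
k = ParameterKernel(base)  # shorthand for ParameterKernel(ParameterHandling.flatten(base)...)

# `k.params` holds the unconstrained parameters (here presumably [log(0.5)],
# since the base kernel itself carries none) and `k.kernel` the unflatten
# closure, so Functors-based optimisers see only `params` as trainable and
# every update they make stays within the kernel's parameter constraints.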