Add Vecchia approximation #147

Merged
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.4.5"
version = "0.4.6"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
@@ -12,11 +12,13 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3 changes: 3 additions & 0 deletions src/ApproximateGPs.jl
@@ -19,6 +19,9 @@ include("LaplaceApproximationModule.jl")
@reexport using .LaplaceApproximationModule:
build_laplace_objective, build_laplace_objective!

include("NearestNeighborsModule.jl")
@reexport using .NearestNeighborsModule: NearestNeighbors

include("deprecations.jl")

include("TestUtils.jl")
115 changes: 115 additions & 0 deletions src/NearestNeighborsModule.jl
@@ -0,0 +1,115 @@
module NearestNeighborsModule
using ..API
using ChainRulesCore
using KernelFunctions, LinearAlgebra, SparseArrays, AbstractGPs, IrrationalConstants

"""
Constructs the matrix ``B`` for which ``f = Bf + \epsilon`` where ``f``
are the values of the GP and ``\epsilon`` is a vector of zero mean
independent Gaussian noise.
This matrix builds the conditional mean for function value ``f_i``
in terms of the function values for previous ``f_j``, where ``j < i``.
See equation (9) of (Datta, A. Nearest neighbor sparse Cholesky
matrices in spatial statistics. 2022).
"""
function make_B(pts::AbstractVector{T}, k::Int, kern::Kernel) where {T}
rows = make_rows(pts, k, kern)
js = make_js(rows, k)
is = make_is(js)
n = length(pts)
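# The assembled B is strictly lower triangular: row i (for i in 2:n) has its
# nonzeros in columns max(1, i - k) through i - 1, i.e. at most k entries per row.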
return sparse(reduce(vcat, is), reduce(vcat, js), reduce(vcat, rows), n, n)
end

function make_rows(pts::AbstractVector{T}, k::Int, kern::Kernel) where {T}
return [make_row(kern, pts[max(1, i - k):(i - 1)], pts[i]) for i in 2:length(pts)]
end

function make_row(kern::Kernel, ns::AbstractVector{T}, p::T) where {T}
return kernelmatrix(kern, ns) \ kern.(ns, p)
end

function make_js(rows, k)
return map(zip(rows, 2:(length(rows) + 1))) do (row, i)
start_ix = max(i - k, 1)
return start_ix:(start_ix + length(row) - 1)
end
end

make_is(js) = [fill(i, length(col_ix)) for (col_ix, i) in zip(js, 2:(length(js) + 1))]

"""
Constructs the diagonal covariance matrix for noise vector ``\epsilon``
for which ``f = Bf + \epsilon``.
See equation (10) of (Datta, A. Nearest neighbor sparse Cholesky
matrices in spatial statistics. 2022).
"""
function make_F(pts::AbstractVector, k::Int, kern::Kernel)
n = length(pts)
vals = [
begin
prior = kern(pts[i], pts[i])
if i == 1
prior
else
ns = pts[max(1, i - k):(i - 1)]
ki = kern.(ns, pts[i])
prior - dot(ki, kernelmatrix(kern, ns) \ ki)
end
end for i in 1:n
]
return Diagonal(vals)
end

@doc raw"""
In a ``k``-nearest neighbor (or Vecchia) Gaussian Process approximation,
we assume that the joint distribution ``p(f_1, f_2, f_3, \dotsc)``
factors as ``\prod_i p(f_i | f_{i-1}, \dotsc, f_{i-k})``, where each ``f_i``
is influenced only by its ``k`` preceding neighbors. This allows us to express
the vector ``f`` as ``Bf + \epsilon``, where ``B`` is a sparse matrix with at most
``k`` nonzero entries per row and ``\epsilon`` is Gaussian distributed with diagonal
covariance ``F``. The precision matrix of the Gaussian process at the
specified points simplifies to ``(I-B)'F^{-1}(I-B)``.
"""
struct NearestNeighbors
k::Int
end

"`InvRoot(U)` is a lazy representation of `inv(UU')`"
struct InvRoot{A}
U::A
end

LinearAlgebra.logdet(A::InvRoot) = -2 * logdet(A.U)
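
# Since A lazily represents inv(U * U'), we have X' * inv(A) * X = (U' * X)' * (U' * X).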

function AbstractGPs.diag_Xt_invA_X(A::InvRoot, X::AbstractVecOrMat)
return AbstractGPs.diag_At_A(A.U' * X)
end

AbstractGPs.Xt_invA_X(A::InvRoot, X::AbstractVecOrMat) = AbstractGPs.At_A(A.U' * X)

# Make a sparse approximation of the square root of the precision matrix
function approx_root_prec(x::AbstractVector, k::Int, kern::Kernel)
F = make_F(x, k, kern)
B = make_B(x, k, kern)
return UpperTriangular((I - B)' * inv(sqrt(F)))
end
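
# The approximate posterior stores α ≈ Σ⁻¹ (y - mean(fx)), computed with the sparse
# factor U, and C = InvRoot(U), a lazy stand-in for the covariance (U U')⁻¹, mirroring
# the (α, C, x, δ) data layout used by AbstractGPs' exact PosteriorGP.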

function AbstractGPs.posterior(
nn::NearestNeighbors, fx::AbstractGPs.FiniteGP, y::AbstractVector
)
kern = fx.f.kernel
U = approx_root_prec(fx.x, nn.k, kern)
δ = y - mean(fx)
α = U * (U' * δ)
C = InvRoot(U)
return AbstractGPs.PosteriorGP(fx.f, (α=α, C=C, x=fx.x, δ=δ))
end
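
# Gaussian log marginal likelihood -½(δ'Σ⁻¹δ + logdet Σ + n log 2π), with the Vecchia
# approximation Σ⁻¹ ≈ U U': the quadratic form is α'δ and logdet Σ is logdet(C) = -2 logdet(U).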

function API.approx_lml(nn::NearestNeighbors, fx::AbstractGPs.FiniteGP, y::AbstractVector)
post = posterior(nn, fx, y)
quadform = post.data.α' * post.data.δ
ld = logdet(post.data.C)
return -(ld + length(y) * eltype(y)(log2π) + quadform) / 2
end

end
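
A minimal sketch (not part of the diff) of the property the module relies on: with k = n - 1 neighbors, each point conditions on all of its predecessors, so the sparse factor U = (I - B)' F^(-1/2) built by approx_root_prec should reproduce the exact precision matrix, U U' ≈ inv(kernelmatrix(kern, x)). The submodule import path and the tolerance below are assumptions for illustration only.

using AbstractGPs, KernelFunctions, LinearAlgebra
using ApproximateGPs
using ApproximateGPs.NearestNeighborsModule: approx_root_prec

x = [1.0, 2.0, 3.5, 4.2, 5.9, 8.0]
kern = SqExponentialKernel()

# With k = length(x) - 1 the Vecchia factorization is exact (up to rounding):
U = approx_root_prec(x, length(x) - 1, kern)
P_vecchia = Matrix(U * U')            # (I - B)' F⁻¹ (I - B), densified for comparison
P_exact = inv(kernelmatrix(kern, x))  # dense precision of the exact prior

isapprox(P_vecchia, P_exact; rtol=1e-6)  # expected to hold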
41 changes: 41 additions & 0 deletions test/NearestNeighborsModule.jl
@@ -0,0 +1,41 @@
@testset "nearest_neighbors" begin
x = [1.0, 2.0, 3.5, 4.2, 5.9, 8.0]
kern = SqExponentialKernel()
fx = GP(kern)(x, 0.0)
x2 = 1.0:0.1:8
y = sin.(x)

@testset "Using all neighbors is the same as the exact GP" begin
opt_pred = mean_and_cov(posterior(NearestNeighbors(length(x) - 1), fx, y)(x2))
pred = mean_and_cov(posterior(fx, y)(x2))
for i in 1:2
@test all(isapprox.(opt_pred[i], pred[i]; atol=1e-4))
end
end

@testset "Using nearest neighbors approximates the exact GP" begin
opt_pred = mean_and_cov(posterior(NearestNeighbors(3), fx, y)(x2))
pred = mean_and_cov(posterior(fx, y)(x2))
for i in 1:2
@test all(isapprox.(opt_pred[i], pred[i]; atol=1e-1))
end
end

@testset "Using nearest neighbors approximates the exact log likelihood" begin
l1 = approx_lml(NearestNeighbors(3), fx, y)
l2 = logpdf(fx, y)
@test isapprox(l1, l2; atol=1e-2)
end

@testset "Zygote can take gradients of the logpdf" begin
function objective(lengthscale::Float64)
kern2 = with_lengthscale(kern, lengthscale)
fx = GP(kern2)(x, 0.0)
return approx_lml(NearestNeighbors(3), fx, y)
end
lml, grads = Zygote.withgradient(objective, 1.0)

@test approx_lml(NearestNeighbors(3), fx, y) ≈ lml
@test all(abs.(grads) .> 0)
end
end
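
For orientation, a minimal usage sketch mirroring the tests above (assuming NearestNeighbors and approx_lml are available after `using ApproximateGPs`, as the reexport and the test imports suggest):

using AbstractGPs, ApproximateGPs, KernelFunctions

x = [1.0, 2.0, 3.5, 4.2, 5.9, 8.0]
y = sin.(x)
fx = GP(SqExponentialKernel())(x, 0.0)

approx = NearestNeighbors(3)            # condition each point on its 3 predecessors
post = posterior(approx, fx, y)         # approximate posterior GP
lml = approx_lml(approx, fx, y)         # approximate log marginal likelihood
m, C = mean_and_cov(post(1.0:0.1:8.0))  # predictions on a fine grid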
7 changes: 6 additions & 1 deletion test/runtests.jl
@@ -15,7 +15,8 @@ using Zygote

using AbstractGPs
using ApproximateGPs
using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule
using ApproximateGPs:
SparseVariationalApproximationModule, LaplaceApproximationModule, NearestNeighborsModule

const GROUP = get(ENV, "GROUP", "All")

@@ -61,6 +62,10 @@ include("test_utils.jl")
include("LaplaceApproximationModule.jl")
println(" ")
@info "Ran laplace tests"

include("NearestNeighborsModule.jl")
println(" ")
@info "Ran nearest neighbors tests"
end

if GROUP == "All" || GROUP == "CUDA"