Issues with Cholesky decompositions on simple benchmark #61

Open
dimitri-rusin opened this issue Oct 23, 2023 · 6 comments

@dimitri-rusin

I'm getting errors from the Cholesky decompositions. Here's how to reproduce them:

  1. Go to: https://colab.research.google.com/drive/1XftMKU7-tWj0cdWjH7XsfiDPBIKXAWZk#scrollTo=OuypIJ7do1qi
  2. Run all cells.
  3. One of two exceptions is then raised, either:
NanError: cholesky_cpu: 3716 of 3721 elements of the torch.Size([61, 61]) tensor are NaN.

or

NotPSDError: Matrix not positive definite after repeatedly adding jitter up to 1.0e-04.

What matrix is being decomposed here? Can I influence or change this matrix using some hyperparameters? What can I do?

Thank you!
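
On the question of which matrix is decomposed: in exact GP inference (HEBO's GP models are built on GPyTorch's exact-GP machinery), the Cholesky factorization is typically applied to the training covariance K(X, X) + σ²I, i.e. the kernel matrix plus observation noise. The numpy sketch below assumes a Matern-1.5 ARD kernel similar to HEBO's default and uses a hypothetical helper `matern15_cov`; it illustrates the idea and is not HEBO's code. It also shows that a single NaN lengthscale turns every entry NaN, a case that no amount of jitter can repair.

```python
import numpy as np

def matern15_cov(X, lengthscale, outputscale, noise):
    """Return K(X, X) + noise * I for a Matern-1.5 ARD kernel (hypothetical helper)."""
    Z = X / lengthscale                          # scale each input dimension by its lengthscale
    diff = Z[:, None, :] - Z[None, :, :]         # pairwise differences, shape (n, n, d)
    r = np.sqrt((diff ** 2).sum(axis=-1))        # scaled Euclidean distances
    sr = np.sqrt(3.0) * r
    K = outputscale * (1.0 + sr) * np.exp(-sr)   # Matern nu = 1.5
    return K + noise * np.eye(X.shape[0])

rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(15, 2))

K = matern15_cov(X, lengthscale=np.array([0.5, 0.5]), outputscale=1.0, noise=1e-2)
L = np.linalg.cholesky(K)                        # succeeds: K is symmetric positive definite

# One NaN hyperparameter poisons every pairwise distance, including the diagonal
K_bad = matern15_cov(X, lengthscale=np.array([np.nan, 0.5]), outputscale=1.0, noise=1e-2)
print(np.isnan(K_bad).sum(), "of", K_bad.size, "entries are NaN")  # 225 of 225 entries are NaN
```

The noise term on the diagonal raises every eigenvalue by σ², which is what keeps K safely positive definite when training points nearly coincide.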

@Takui9

Takui9 commented Nov 15, 2023

Same problem here: I get these errors when using contextual HEBO.

@AntGro
Collaborator

AntGro commented Nov 15, 2023

To improve stability, you can add upper bounds on the kernel lengthscales by replacing HEBO/hebo/models/gp/gp_utils.py with this version:

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.

# This program is free software; you can redistribute it and/or modify it under
# the terms of the MIT license.

# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.

import numpy as np
import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor


from gpytorch.kernels import MaternKernel, ScaleKernel, ProductKernel
from gpytorch.priors  import GammaPrior
from gpytorch.constraints.constraints import LessThan

from ..layers import EmbTransform

class DummyFeatureExtractor(nn.Module):
    def __init__(self, num_cont, num_enum, num_uniqs = None, emb_sizes = None):
        super().__init__()
        self.num_cont  = num_cont
        self.num_enum  = num_enum
        self.total_dim = num_cont
        if num_enum > 0:
            assert num_uniqs is not None
            self.emb_trans  = EmbTransform(num_uniqs, emb_sizes = emb_sizes)
            self.total_dim += self.emb_trans.num_out

    def forward(self, x : FloatTensor, xe : LongTensor):
        x_all = x
        if self.num_enum > 0:
            x_all = torch.cat([x, self.emb_trans(xe)], dim = 1)
        return x_all

def default_kern(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000):
    if fe is None:
        has_num  = x  is not None and x.shape[1]  > 0
        has_enum = xe is not None and xe.shape[1] > 0
        kerns    = []
        if has_num:
            ard_num_dims = x.shape[1] if ard_kernel else None
            kernel       = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = torch.arange(x.shape[1]),
                                        lengthscale_constraint=LessThan(5))
            if ard_kernel:
                lscales = kernel.lengthscale.detach().clone().view(1, -1)
                for i in range(x.shape[1]):
                    idx = np.random.choice(x.shape[0], min(x.shape[0], max_x), replace = False)
                    lscales[0, i] = torch.pdist(x[idx, i].view(-1, 1)).median().clamp(min = 0.02)
                kernel.lengthscale = lscales
            kerns.append(kernel)
        if has_enum:
            kernel = MaternKernel(nu = 1.5, active_dims = torch.arange(x.shape[1], total_dim),
                                        lengthscale_constraint=LessThan(5))
            kerns.append(kernel)
        final_kern = ScaleKernel(ProductKernel(*kerns), outputscale_prior = GammaPrior(0.5, 0.5))
        final_kern.outputscale = y[torch.isfinite(y)].var()
        return final_kern
    else:
        if ard_kernel:
            kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim,
                                        lengthscale_constraint=LessThan(5)))
        else:
            kernel = ScaleKernel(MaternKernel(nu = 1.5))
        kernel.outputscale = y[torch.isfinite(y)].var()
        return kernel

I've tested it on the example provided in the original issue comment and it runs without errors.

After further investigation, we'll include this in the repo directly.
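
The ARD initialization inside default_kern sets each lengthscale to the median pairwise distance along that input dimension (over a random subsample of at most max_x rows), clamped below at 0.02. A numpy sketch of the same heuristic, assuming at least two rows; `median_lengthscales` is a hypothetical name. It takes the lower median because torch.median returns the lower of the two middle values when the count is even.

```python
import numpy as np

def median_lengthscales(X, max_x=1000, min_ls=0.02, rng=None):
    """Per-dimension median pairwise distance, mirroring default_kern's ARD init (sketch)."""
    rng = np.random.default_rng() if rng is None else rng
    n, d = X.shape
    out = np.empty(d)
    for i in range(d):
        idx = rng.choice(n, min(n, max_x), replace=False)     # subsample rows, per dimension
        col = X[idx, i]
        iu = np.triu_indices(len(col), k=1)                   # each unordered pair once, like torch.pdist
        dists = np.abs(col[:, None] - col[None, :])[iu]
        lower_median = np.sort(dists)[(len(dists) - 1) // 2]  # torch.median returns the lower median
        out[i] = max(lower_median, min_ls)                    # clamp(min = 0.02)
    return out

X = np.array([[0.0, 5.0], [1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])
print(median_lengthscales(X, rng=np.random.default_rng(0)))   # constant column falls back to 0.02
```

If the inputs are scaled to a bounded range before fitting, these medians stay well below the upper bound of 5 used above; with unscaled inputs an initial value could land outside the constraint.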

dimitri-rusin pushed a commit to dimitri-rusin/hebo_on_bbob that referenced this issue Nov 16, 2023
…the function `default_kern_rd` in the file as well.
dimitri-rusin pushed a commit to dimitri-rusin/hebo_on_bbob that referenced this issue Nov 20, 2023
The change from huawei-noah/HEBO#61 (comment) is integrated here, together with a configuration file whose access mode is changed.
@AntGro
Collaborator

AntGro commented Nov 29, 2023

@dimitri-rusin
Actually, we also need a lower bound on the lengthscales to prevent NaN or Inf values, so I replaced LessThan with Interval:

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.

# This program is free software; you can redistribute it and/or modify it under
# the terms of the MIT license.

# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.

import numpy as np
import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor


from gpytorch.kernels import MaternKernel, ScaleKernel, ProductKernel
from gpytorch.priors  import GammaPrior
from gpytorch.constraints.constraints import Interval

from ..layers import EmbTransform

class DummyFeatureExtractor(nn.Module):
    def __init__(self, num_cont, num_enum, num_uniqs = None, emb_sizes = None):
        super().__init__()
        self.num_cont  = num_cont
        self.num_enum  = num_enum
        self.total_dim = num_cont
        if num_enum > 0:
            assert num_uniqs is not None
            self.emb_trans  = EmbTransform(num_uniqs, emb_sizes = emb_sizes)
            self.total_dim += self.emb_trans.num_out

    def forward(self, x : FloatTensor, xe : LongTensor):
        x_all = x
        if self.num_enum > 0:
            x_all = torch.cat([x, self.emb_trans(xe)], dim = 1)
        return x_all

def default_kern(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000):
    if fe is None:
        has_num  = x  is not None and x.shape[1]  > 0
        has_enum = xe is not None and xe.shape[1] > 0
        kerns    = []
        if has_num:
            ard_num_dims = x.shape[1] if ard_kernel else None
            kernel       = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = torch.arange(x.shape[1]),
                                        lengthscale_constraint=Interval(1e-5, 5))
            if ard_kernel:
                lscales = kernel.lengthscale.detach().clone().view(1, -1)
                for i in range(x.shape[1]):
                    idx = np.random.choice(x.shape[0], min(x.shape[0], max_x), replace = False)
                    lscales[0, i] = torch.pdist(x[idx, i].view(-1, 1)).median().clamp(min = 0.02)
                kernel.lengthscale = lscales
            kerns.append(kernel)
        if has_enum:
            kernel = MaternKernel(nu = 1.5, active_dims = torch.arange(x.shape[1], total_dim),
                                        lengthscale_constraint=Interval(1e-5, 5))
            kerns.append(kernel)
        final_kern = ScaleKernel(ProductKernel(*kerns), outputscale_prior = GammaPrior(0.5, 0.5))
        final_kern.outputscale = y[torch.isfinite(y)].var()
        return final_kern
    else:
        if ard_kernel:
            kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim,
                                        lengthscale_constraint=Interval(1e-5, 5)))
        else:
            kernel = ScaleKernel(MaternKernel(nu = 1.5))
        kernel.outputscale = y[torch.isfinite(y)].var()
        return kernel
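
Why the lower bound matters: as far as I know, GPyTorch's LessThan(upper) maps the raw parameter through upper - softplus(-raw), which can reach any value below the bound, including zero and negative lengthscales, while Interval(lower, upper) uses lower + (upper - lower) * sigmoid(raw) and stays inside the range. A numpy sketch of the two transforms (function names are hypothetical, not GPyTorch API):

```python
import numpy as np

def less_than(raw, upper):
    # Upper bound only: the result can be any value below `upper`, including <= 0
    return upper - np.logaddexp(0.0, -raw)                  # softplus(z) = log(1 + exp(z))

def interval(raw, lower, upper):
    # Both bounds: the result always lies between `lower` and `upper`
    return lower + (upper - lower) / (1.0 + np.exp(-raw))   # sigmoid

raw = np.array([-50.0, 0.0, 50.0])
print(less_than(raw, 5.0))        # first entry is -45: a negative "lengthscale"
print(interval(raw, 1e-5, 5.0))   # every entry stays between 1e-5 and 5

# A zero lengthscale breaks the scaled distance r / l: the diagonal (r = 0) becomes NaN
with np.errstate(divide="ignore", invalid="ignore"):
    print(np.array([0.0, 0.3]) / 0.0)   # [nan inf]
```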

@kegl

kegl commented Jan 4, 2024

Actually, it should be this, right? default_kern_rd is also needed.

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.

# This program is free software; you can redistribute it and/or modify it under
# the terms of the MIT license.

# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.

import numpy as np
import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor


from gpytorch.kernels import MaternKernel, ScaleKernel, ProductKernel, AdditiveKernel
from gpytorch.priors  import GammaPrior
from gpytorch.constraints.constraints import Interval

from ..layers import EmbTransform

class DummyFeatureExtractor(nn.Module):
    def __init__(self, num_cont, num_enum, num_uniqs = None, emb_sizes = None):
        super().__init__()
        self.num_cont  = num_cont
        self.num_enum  = num_enum
        self.total_dim = num_cont
        if num_enum > 0:
            assert num_uniqs is not None
            self.emb_trans  = EmbTransform(num_uniqs, emb_sizes = emb_sizes)
            self.total_dim += self.emb_trans.num_out

    def forward(self, x : FloatTensor, xe : LongTensor):
        x_all = x
        if self.num_enum > 0:
            x_all = torch.cat([x, self.emb_trans(xe)], dim = 1)
        return x_all

def default_kern(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000):
    if fe is None:
        has_num  = x  is not None and x.shape[1]  > 0
        has_enum = xe is not None and xe.shape[1] > 0
        kerns    = []
        if has_num:
            ard_num_dims = x.shape[1] if ard_kernel else None
            kernel       = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = torch.arange(x.shape[1]),
                                        lengthscale_constraint=Interval(1e-5, 5))
            if ard_kernel:
                lscales = kernel.lengthscale.detach().clone().view(1, -1)
                for i in range(x.shape[1]):
                    idx = np.random.choice(x.shape[0], min(x.shape[0], max_x), replace = False)
                    lscales[0, i] = torch.pdist(x[idx, i].view(-1, 1)).median().clamp(min = 0.02)
                kernel.lengthscale = lscales
            kerns.append(kernel)
        if has_enum:
            kernel = MaternKernel(nu = 1.5, active_dims = torch.arange(x.shape[1], total_dim),
                                        lengthscale_constraint=Interval(1e-5, 5))
            kerns.append(kernel)
        final_kern = ScaleKernel(ProductKernel(*kerns), outputscale_prior = GammaPrior(0.5, 0.5))
        final_kern.outputscale = y[torch.isfinite(y)].var()
        return final_kern
    else:
        if ard_kernel:
            kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim,
                                        lengthscale_constraint=Interval(1e-5, 5)))
        else:
            kernel = ScaleKernel(MaternKernel(nu = 1.5))
        kernel.outputscale = y[torch.isfinite(y)].var()
        return kernel
    
def default_kern_rd(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000, E=0.2):
    '''
    Get a default kernel with random decompositions. 0 <= E <= 1 specifies the random tree connectivity.
    '''
    kernels = []
    random_graph = get_random_graph(total_dim, E)
    for clique in random_graph:
        if fe is None:
            num_dims  = tuple(dim for dim in clique if dim < x.shape[1])
            enum_dims = tuple(dim for dim in clique if x.shape[1] <= dim < total_dim)
            clique_kernels = []
            if len(num_dims) > 0:
                ard_num_dims = len(num_dims) if ard_kernel else None
                num_kernel       = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = num_dims)
                if ard_kernel:
                    lscales = num_kernel.lengthscale.detach().clone().view(1, -1)
                    if len(num_dims) > 1:
                        for dim_no, dim_name in enumerate(num_dims):
                            # Subsample data points (rows), not dimension indices
                            idx = np.random.choice(x.shape[0], min(x.shape[0], max_x), replace = False)
                            lscales[0, dim_no] = torch.pdist(x[idx, dim_name].view(-1, 1)).median().clamp(min = 0.02)
                    num_kernel.lengthscale = lscales
                clique_kernels.append(num_kernel)
            if len(enum_dims) > 0:
                enum_kernel = MaternKernel(nu = 1.5, active_dims = enum_dims)
                clique_kernels.append(enum_kernel)

            kernel = ScaleKernel(ProductKernel(*clique_kernels), outputscale_prior = GammaPrior(0.5, 0.5))
        else:
            if ard_kernel:
                # One lengthscale per active dimension of this clique, not per total_dim
                kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = len(clique), active_dims=tuple(clique)))
            else:
                kernel = ScaleKernel(MaternKernel(nu = 1.5, active_dims=tuple(clique)))

        kernels.append(kernel)

    final_kern = ScaleKernel(AdditiveKernel(*kernels), outputscale_prior = GammaPrior(0.5, 0.5))
    final_kern.outputscale = y[torch.isfinite(y)].var()
    return final_kern

@kegl

kegl commented Jan 15, 2024

I tried this and it was not enough. What worked was catching the Cholesky failure and increasing the jitter until the factorization succeeds:

    def fit(self, Xc : Tensor, Xe : Tensor, y : Tensor):
        Xc, Xe, y = filter_nan(Xc, Xe, y, 'all')
        self.fit_scaler(Xc, Xe, y)
        Xc, Xe, y = self.xtrans(Xc, Xe, y)

        assert(Xc.shape[1] == self.num_cont)
        assert(Xe.shape[1] == self.num_enum)
        assert(y.shape[1]  == self.num_out)

        self.Xc = Xc
        self.Xe = Xe
        self.y  = y

        n_constr = GreaterThan(self.noise_lb)
        n_prior  = LogNormalPrior(np.log(self.noise_guess), 0.5)
        self.lik = GaussianLikelihood(noise_constraint = n_constr, noise_prior = n_prior)
        self.gp  = GPyTorchModel(self.Xc, self.Xe, self.y, self.lik, **self.conf)

        self.gp.likelihood.noise  = max(1e-2, self.noise_lb)

        self.gp.train()
        self.lik.train()

        if self.optimizer.lower() == 'lbfgs':
            opt = torch.optim.LBFGS(self.gp.parameters(), lr = self.lr, max_iter = 5, line_search_fn = 'strong_wolfe')
        elif self.optimizer == 'psgld':
            opt = pSGLD(self.gp.parameters(), lr = self.lr, factor = 1. / y.shape[0], pretrain_step = self.num_epochs // 10)
        else:
            opt = torch.optim.Adam(self.gp.parameters(), lr = self.lr)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.lik, self.gp)
        for epoch in range(self.num_epochs):
            jitter = 10 ** -8
            cont = True
            while cont:
                cont = False
                cholesky_jitter._set_value(
                    double_value=jitter, float_value=100*jitter, half_value=10000*jitter)
                def closure():
                    dist = self.gp(self.Xc, self.Xe)
                    loss = -1 * mll(dist, self.y.squeeze())
                    opt.zero_grad()
                    loss.backward()
                    return loss
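                try:
                    opt.step(closure)
                except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
                    if jitter >= 1e-2:  # illustrative cap (assumption): stop instead of retrying forever
                        raise
                    jitter *= 10
                    cont = True
                    print(f'jitter = {jitter}')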
            if self.verbose and ((epoch + 1) % self.print_every == 0 or epoch == 0):
                print('After %d epochs, loss = %g' % (epoch + 1, closure().item()), flush = True)
        self.gp.eval()
        self.lik.eval()

    def predict(self, Xc, Xe):
        Xc, Xe = self.xtrans(Xc, Xe)
        with gpytorch.settings.fast_pred_var(), gpytorch.settings.debug(False):
            jitter = 10 ** -8
            cont = True
            while cont:
                cont = False
                cholesky_jitter._set_value(
                    double_value=jitter, float_value=100*jitter, half_value=10000*jitter)
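                try:
                    pred = self.gp(Xc, Xe)
                except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
                    if jitter >= 1e-2:  # illustrative cap (assumption): stop instead of retrying forever
                        raise
                    jitter *= 10
                    cont = True
                    print(f'jitter = {jitter}')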
            if self.pred_likeli:
                pred = self.lik(pred)
            mu_  = pred.mean.reshape(-1, self.num_out)
            var_ = pred.variance.reshape(-1, self.num_out)
        mu  = self.yscaler.inverse_transform(mu_)
        var = var_ * self.yscaler.std**2
        return mu, var.clamp(min = torch.finfo(var.dtype).eps)
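
Retrying under a catch-all except can spin forever on a NanError, which jitter cannot fix. A bounded variant of the same escalating-jitter idea, sketched with numpy; `cholesky_with_jitter`, its defaults, and the retry cap are assumptions, not HEBO or GPyTorch API. In GPyTorch itself, the public gpytorch.settings.cholesky_jitter(float_value=..., double_value=...) context manager scopes the jitter to a `with` block instead of changing global state through the private _set_value.

```python
import numpy as np

def cholesky_with_jitter(K, start=1e-8, factor=10.0, max_tries=6):
    """Return (L, jitter) where L @ L.T approximately equals K + jitter * I (hypothetical helper)."""
    if not np.isfinite(K).all():
        # NaN/Inf entries come from bad inputs or hyperparameters; jitter cannot repair them
        raise ValueError("covariance matrix has non-finite entries")
    try:
        return np.linalg.cholesky(K), 0.0
    except np.linalg.LinAlgError:
        pass                                   # not numerically positive definite: add jitter
    eye = np.eye(K.shape[0])
    jitter = start
    for _ in range(max_tries):
        try:
            return np.linalg.cholesky(K + jitter * eye), jitter
        except np.linalg.LinAlgError:
            jitter *= factor                   # escalate: 1e-8, 1e-7, ...
    raise np.linalg.LinAlgError(f"not positive definite after {max_tries} jitter attempts")

L, jitter = cholesky_with_jitter(np.array([[1.0, 1.0], [1.0, 1.0]]))  # singular, but PSD
print(jitter)  # 1e-08
```

Checking for non-finite entries first turns the NaN case into an immediate, explicit failure, and the retry cap plays the same role as the jitter limit mentioned in the NotPSDError message at the top of this issue.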

kegl added a commit to kegl/HEBO that referenced this issue Jan 15, 2024
@muazhari

muazhari commented Aug 11, 2024

Are there any updates or fixes?
I am still getting the error, even in a simple use case.
Reproducible code: https://gist.github.com/muazhari/85b7469902cdb7b3fba49a065b212f40

NanError: cholesky_cpu: 225 of 225 elements of the torch.Size([15, 15]) tensor are NaN.
