[BUG] UMAP fails to correctly embed small datasets when random_state is not set. #6024

Open
kc-howe opened this issue Aug 15, 2024 · 4 comments
Labels: ? - Needs Triage (Need team to review and classify); bug (Something isn't working)

Comments

kc-howe commented Aug 15, 2024

Describe the bug

UMAP fails to correctly embed small datasets when random_state is not set (or, rather, when it is set to None). This affects datasets with fewer than roughly 90 samples.

[Image: umap_bug]

Steps/Code to reproduce bug

Running any dataset with fewer than about 90 samples through both unseeded and seeded UMAP is enough to reproduce this bug. Below is a simple example using make_blobs data.

from cuml import UMAP
from sklearn.datasets import make_blobs

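# Three blobs of 15 points each (45 samples), well under the ~90-sample threshold
X, y = make_blobs(n_samples=[15, 15, 15], n_features=5, random_state=42)

# Identical settings except for the seed
mapper = UMAP(n_neighbors=5, random_state=None)
mapper_seeded = UMAP(n_neighbors=5, random_state=42)

embedding = mapper.fit_transform(X)  # unseeded: the poorly embedded result
embedding_seeded = mapper_seeded.fit_transform(X)  # seeded: embeds as expected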

The output looks something like this:
[Image: umap_bug_simple]
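To compare the two embeddings with a number rather than by eye, one option is a nearest-neighbour label-agreement score: for each point, the fraction of its k nearest neighbours in the embedding that share its make_blobs label. The sketch below is plain Python; `knn_label_agreement` is a hypothetical helper, not part of cuML or scikit-learn, and the default k=5 simply mirrors the n_neighbors used above.

```python
import math
from typing import Sequence


def knn_label_agreement(
    embedding: Sequence[Sequence[float]],
    labels: Sequence[int],
    k: int = 5,
) -> float:
    """Hypothetical helper: share of each point's k nearest embedded
    neighbours (excluding itself) that carry the same label.

    Near 1.0 means the blobs stay separated in the embedding; a thoroughly
    mixed embedding of three equal blobs drops toward chance (about 1/3).
    """
    n = len(embedding)
    if n != len(labels):
        raise ValueError("embedding and labels must have the same length")
    if not 0 < k < n:
        raise ValueError("k must be between 1 and n - 1")

    agree = 0
    for i in range(n):
        # Brute force: sort every other point by Euclidean distance to point i
        others = sorted(
            (j for j in range(n) if j != i),
            key=lambda j: math.dist(embedding[i], embedding[j]),
        )
        agree += sum(labels[j] == labels[i] for j in others[:k])
    return agree / (n * k)
```

Assuming the embeddings come back as NumPy arrays, calling `knn_label_agreement(embedding_seeded.tolist(), y.tolist())` and the same on `embedding` would put a figure on the gap between the seeded and unseeded runs. The brute-force loop is O(n² log n), which is fine at the dataset sizes discussed here but not for large data.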

Expected behavior
Failing to pass a random state seed should not interfere with the quality of UMAP embeddings.

Environment details:

  • Environment location: Bare-metal
  • Linux Distro/Architecture: Pop!_OS 22.04 LTS x86_64
  • GPU Model/Driver: GeForce RTX 4070 / Driver 550
  • CUDA: 12.3
  • Method of cuDF & cuML install: conda
# packages in environment at /home/kenneth/miniconda3:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
airborn                   0.0.0                    pypi_0    pypi
archspec                  0.2.1              pyhd3eb1b0_0  
asttokens                 2.0.5              pyhd3eb1b0_0  
blinker                   1.7.0                    pypi_0    pypi
boltons                   23.0.0          py310h06a4308_0  
bzip2                     1.0.8                h7b6447c_0  
c-ares                    1.19.1               h5eee18b_0  
ca-certificates           2023.12.12           h06a4308_0  
caerus                    0.1.9                    pypi_0    pypi
certifi                   2024.2.2        py310h06a4308_0  
cffi                      1.16.0          py310h5eee18b_0  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.1.7                    pypi_0    pypi
comm                      0.1.2           py310h06a4308_0  
conda                     24.1.1          py310h06a4308_0  
conda-content-trust       0.1.3           py310h06a4308_0  
conda-libmamba-solver     24.1.0             pyhd3eb1b0_0  
conda-package-handling    2.2.0           py310h06a4308_0  
conda-package-streaming   0.9.0           py310h06a4308_0  
configparser              6.0.1                    pypi_0    pypi
contourpy                 1.1.1                    pypi_0    pypi
cryptography              39.0.1          py310h9ce1e76_2  
cubinlinker-cu11          0.3.0.post2              pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
cython                    0.29.36                  pypi_0    pypi
dash                      2.16.1                   pypi_0    pypi
dash-core-components      2.0.0                    pypi_0    pypi
dash-html-components      2.0.0                    pypi_0    pypi
dash-table                5.0.0                    pypi_0    pypi
debugpy                   1.6.7           py310h6a678d5_0  
decorator                 5.1.1              pyhd3eb1b0_0  
distro                    1.8.0           py310h06a4308_0  
dtwalign                  0.1.1                    pypi_0    pypi
exceptiongroup            1.2.0           py310h06a4308_0  
executing                 0.8.3              pyhd3eb1b0_0  
findpeaks                 2.5.4                    pypi_0    pypi
flask                     3.0.2                    pypi_0    pypi
fmt                       9.1.0                hdb19cb5_0  
fonttools                 4.43.1                   pypi_0    pypi
hdbscan                   0.8.33                   pypi_0    pypi
icu                       73.1                 h6a678d5_0  
idna                      3.4             py310h06a4308_0  
importlib-metadata        7.0.2                    pypi_0    pypi
iniconfig                 2.0.0                    pypi_0    pypi
ipykernel                 6.25.2             pyh2140261_0    conda-forge
ipython                   8.20.0          py310h06a4308_0  
itsdangerous              2.1.2                    pypi_0    pypi
jedi                      0.18.1          py310h06a4308_1  
joblib                    1.3.2                    pypi_0    pypi
jsonpatch                 1.32               pyhd3eb1b0_0  
jsonpointer               2.1                pyhd3eb1b0_0  
jupyter_client            8.6.0           py310h06a4308_0  
jupyter_core              5.5.0           py310h06a4308_0  
kiwisolver                1.4.5                    pypi_0    pypi
krb5                      1.20.1               h143b758_1  
ld_impl_linux-64          2.38                 h1181459_1  
libarchive                3.6.2                h6ac8c49_2  
libcurl                   8.5.0                h251f7ec_0  
libedit                   3.1.20230828         h5eee18b_0  
libev                     4.33                 h7f8727e_1  
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libmamba                  1.5.6                haf1ee3a_0  
libmambapy                1.5.6           py310h2dafd23_0  
libnghttp2                1.57.0               h2d74bed_0  
libsodium                 1.0.18               h7b6447c_0  
libsolv                   0.7.24               he621ea3_0  
libssh2                   1.10.0               hdbd6064_2  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
libxml2                   2.10.4               hf1b16e4_1  
llvmlite                  0.41.0                   pypi_0    pypi
lz4-c                     1.9.4                h6a678d5_0  
matplotlib                3.8.0                    pypi_0    pypi
matplotlib-inline         0.1.6           py310h06a4308_0  
menuinst                  2.0.2           py310h06a4308_0  
mined-thd                 0.8.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.5.6           py310h06a4308_0  
networkx                  3.1                      pypi_0    pypi
numba                     0.58.0                   pypi_0    pypi
numpy                     1.25.2                   pypi_0    pypi
openssl                   3.0.13               h7f8727e_0  
packaging                 23.1            py310h06a4308_0  
pandas                    2.1.1                    pypi_0    pypi
parso                     0.8.3              pyhd3eb1b0_0  
pcre2                     10.42                hebb0a14_0  
peakdetect                1.1                      pypi_0    pypi
pexpect                   4.8.0              pyhd3eb1b0_3  
pillow                    10.0.1                   pypi_0    pypi
pip                       22.3.1          py310h06a4308_0  
platformdirs              3.10.0          py310h06a4308_0  
plotly                    5.17.0                   pypi_0    pypi
pluggy                    1.4.0                    pypi_0    pypi
prompt-toolkit            3.0.43          py310h06a4308_0  
prompt_toolkit            3.0.43               hd3eb1b0_0  
protonup                  0.1.5                    pypi_0    pypi
psutil                    5.9.0           py310h5eee18b_0  
ptxcompiler-cu11          0.8.1.post1              pypi_0    pypi
ptyprocess                0.7.0              pyhd3eb1b0_2  
pure_eval                 0.2.2              pyhd3eb1b0_0  
pybind11-abi              4                    hd3eb1b0_1  
pycosat                   0.6.6           py310h5eee18b_0  
pycparser                 2.21               pyhd3eb1b0_0  
pygments                  2.15.1          py310h06a4308_1  
pygraphviz                1.11                     pypi_0    pypi
pynndescent               0.5.10                   pypi_0    pypi
pyparsing                 3.1.1                    pypi_0    pypi
pytest                    8.0.0                    pypi_0    pypi
python                    3.10.13              h955ad1f_0  
pytho

Additional context
This behavior is independent of the choice of n_epochs and of the initialization algorithm.

kc-howe added the "? - Needs Triage" and "bug" labels Aug 15, 2024
viclafargue (Contributor) commented

In deterministic mode (i.e., when random_state is set), the iterations are processed differently, but whether this explains the difference observed here would need further investigation. That said, smaller datasets definitely require a larger n_epochs value to converge; for small datasets, we recommend 500 epochs. Also, if you are using nn_descent as the build_algo, I believe it is not suitable for very small datasets (< 150 rows).
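The advice above can be captured as keyword arguments for the constructor. This is a minimal sketch under stated assumptions: `small_dataset_umap_kwargs` is a hypothetical helper, the comment gives no precise cut-off for "small", so the 150-row figure is reused as the threshold, and build_algo is only included when the caller says the installed cuML accepts it.

```python
from typing import Any, Dict

# Assumption: no exact cut-off for "small" is given, so this reuses the
# ~150-row figure quoted for nn_descent.
SMALL_DATASET_ROWS = 150
SMALL_DATASET_EPOCHS = 500  # epoch count recommended for small datasets


def small_dataset_umap_kwargs(n_samples: int, supports_build_algo: bool) -> Dict[str, Any]:
    """Hypothetical helper: extra UMAP kwargs for small inputs (empty dict otherwise)."""
    kwargs: Dict[str, Any] = {}
    if n_samples < SMALL_DATASET_ROWS:
        # More optimisation epochs, since small datasets converge more slowly
        kwargs["n_epochs"] = SMALL_DATASET_EPOCHS
        if supports_build_algo:
            # Steer clear of nn_descent on very small inputs; only pass
            # build_algo to versions that accept it (not all releases do)
            kwargs["build_algo"] = "brute_force_knn"
    return kwargs
```

For the 45-sample repro this would be used as, e.g., `UMAP(n_neighbors=5, **small_dataset_umap_kwargs(len(X), supports_build_algo=False))`.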

kc-howe (Author) commented Aug 19, 2024

Running with higher n_epochs doesn't appear to resolve the issue. Here is the result of the above example code run with n_epochs=200_000:

[Image: umap_bug_200k_epochs]

Given that the example code above leaves most parameters at their defaults, build_algo should default to 'auto' and therefore fall back to brute-force KNN, per the documentation (and the source code itself). Just to be sure, I manually set build_algo='brute_force_knn' and got the following warning:

[I] [12:50:27.494401] Unused keyword parameter: build_algo during cuML estimator initialization

Looking into where this warning likely originates, I suspect it is related to the fact that neither build_algo nor build_kwds appears in UMAP's __init__ signature as reported by inspect.signature. Here's the output of list(inspect.signature(UMAP.__init__).parameters.keys()):

['self',
 'n_neighbors',
 'n_components',
 'metric',
 'metric_kwds',
 'n_epochs',
 'learning_rate',
 'min_dist',
 'spread',
 'set_op_mix_ratio',
 'local_connectivity',
 'repulsion_strength',
 'negative_sample_rate',
 'transform_queue_size',
 'init',
 'a',
 'b',
 'target_n_neighbors',
 'target_weight',
 'target_metric',
 'hash_input',
 'random_state',
 'precomputed_knn',
 'callback',
 'handle',
 'verbose',
 'output_type']
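Rather than passing build_algo and waiting for the log message, a caller could check the signature up front. The sketch below assumes, as the listing above suggests, that inspect.signature reflects the parameters the installed version names explicitly; `split_supported_kwargs` is a hypothetical helper, not a cuML API.

```python
import inspect
from typing import Any, Dict, Tuple


def split_supported_kwargs(
    cls: type, kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: split kwargs by whether cls.__init__ names them."""
    params = inspect.signature(cls.__init__).parameters
    # Only explicitly named parameters count; a name that would merely be
    # swallowed by a catch-all **kwargs is reported as unsupported.
    named = {
        name
        for name, p in params.items()
        if name != "self"
        and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                       inspect.Parameter.KEYWORD_ONLY)
    }
    supported = {k: v for k, v in kwargs.items() if k in named}
    unsupported = {k: v for k, v in kwargs.items() if k not in named}
    return supported, unsupported
```

Based on the listing above, `split_supported_kwargs(UMAP, {"n_neighbors": 5, "build_algo": "brute_force_knn"})` on this install would be expected to put build_algo in the second (unsupported) dict.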

I can confirm that whether build_algo is passed to __init__ or left at its default, the resulting UMAP object has no build_algo attribute; attempting to access it raises an AttributeError. Looking at the __init__ code, I can't immediately see why, but I'm sure you guys will have a better idea.

viclafargue (Contributor) commented Aug 19, 2024

Which version of RAPIDS are you using? This message should simply mean that the build_algo parameter is not available in your version (probably 24.06 or earlier). It should not be related to the issue at hand.
However, if the issue occurs regardless of the number of epochs or the initialization, something may be going wrong inside the fast, non-deterministic iterations of the algorithm (perhaps specifically with smaller datasets?). cc @trivialfis, who might have insights here.
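Until the cause is found, the original report hints at a possible stopgap: the seeded runs embedded correctly, so a caller who wants unseeded behaviour could draw a fresh seed per fit instead of passing None. This is a hedged sketch, not a fix: `resolve_seed` is a hypothetical helper, the 2**32 upper bound on seeds is an assumption, and the seeded (deterministic) path may be slower than the fast non-deterministic one.

```python
import secrets
from typing import Optional

# Assumption: any non-negative integer below 2**32 is an acceptable seed.
_MAX_SEED = 2**32


def resolve_seed(random_state: Optional[int]) -> int:
    """Hypothetical helper: keep an explicit seed, or draw a fresh one for None.

    Passing the result to UMAP(random_state=...) keeps run-to-run variety for
    callers that asked for no seed, while routing every fit through the
    seeded path that embedded correctly in this report.
    """
    if random_state is None:
        # Cryptographically random draw in [0, _MAX_SEED)
        return secrets.randbelow(_MAX_SEED)
    return random_state
```

In the repro above this would look like `UMAP(n_neighbors=5, random_state=resolve_seed(None))`.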

kc-howe (Author) commented Aug 19, 2024

I'm using version 24.06, so that checks out.

Projects
None yet
Development

No branches or pull requests

2 participants