
figure out a better run constraint for jax #197

Open
beckermr opened this issue Sep 27, 2023 · 9 comments

@beckermr
Member

Right now the jax run constraint is simply >={{ version }}. I've gotten reports of installation issues when jax gets too far ahead of jaxlib. Maybe we should pin to x.x or something?
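To make the "pin to x.x" idea concrete, here is a minimal sketch of how a pin expression could turn a version into a bounded constraint. The helper names (`parse_version`, `max_pin_constraint`) are hypothetical, it assumes purely numeric dotted versions (no rc/dev suffixes), and it only approximates the spirit of conda-build's max_pin behavior rather than reproducing its exact output.

```python
from __future__ import annotations


def parse_version(v: str) -> tuple[int, ...]:
    """Parse a purely numeric dotted version like "0.4.16" (assumption: no suffixes)."""
    return tuple(int(part) for part in v.split("."))


def max_pin_constraint(name: str, version: str, max_pin: str = "x.x") -> str:
    """Hypothetical helper: build '<name> >=<version>,<upper>' from a pin expression.

    The number of 'x' fields in max_pin is how many leading version
    components are kept; the last kept component is bumped by one to
    form an exclusive upper bound.
    """
    depth = len(max_pin.split("."))
    parts = list(parse_version(version))
    if depth > len(parts):
        raise ValueError(f"max_pin {max_pin!r} is deeper than version {version!r}")
    upper = parts[:depth]
    upper[-1] += 1  # bump the last pinned component
    upper_str = ".".join(str(p) for p in upper)
    return f"{name} >={version},<{upper_str}"


print(max_pin_constraint("jax", "0.4.16"))           # jax >=0.4.16,<0.5
print(max_pin_constraint("jax", "0.4.16", "x.x.x"))  # jax >=0.4.16,<0.4.17
```

The depth of the pin decides how far jax may drift: with x.x every release sharing the first two fields is allowed, while x.x.x allows only the exact release series.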

@JoanneBogart

JoanneBogart commented Sep 29, 2023

This happened to me when I was inadvertently using an old Python (3.5, I think). The jax version was far ahead of the jaxlib version, but I couldn't update to a newer jaxlib. The install completed; the problem only showed up at runtime. When I set up another environment with Python 3.10 and started over, the problem went away.

@michaelosthege

We're having this problem in the CI pipelines of pymc-devs/pymc.

Very annoying, because we don't really want to pin exact jax/jaxlib versions.

@beckermr
Member Author

I think we should pin jax to the same jaxlib version, or maybe allow jax to be at most 2 minor versions ahead. Any thoughts @conda-forge/jax @conda-forge/jaxlib?
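The "same version or at most 2 minor versions ahead" rule can be stated as a simple check on a (jax, jaxlib) pair. This is a minimal sketch with hypothetical names; it assumes numeric dotted versions, and which field counts as "minor" is itself an assumption, so it is exposed as the `component` parameter (1 means the second field of X.Y.Z; for releases numbered like 0.4.16 one might prefer component=2).

```python
from __future__ import annotations


def parse_version(v: str) -> tuple[int, ...]:
    """Parse a purely numeric dotted version (assumption: no rc/dev suffixes)."""
    return tuple(int(part) for part in v.split("."))


def within_window(
    jax_version: str,
    jaxlib_version: str,
    max_ahead: int = 2,
    component: int = 1,
) -> bool:
    """Hypothetical check: jax must be at or ahead of jaxlib, by at most
    `max_ahead` steps in the chosen `component`; all earlier fields must match."""
    jax = parse_version(jax_version)
    lib = parse_version(jaxlib_version)
    if jax[:component] != lib[:component]:
        return False  # e.g. different major version
    if jax < lib:
        return False  # jax older than jaxlib is never allowed
    return jax[component] - lib[component] <= max_ahead


print(within_window("0.4.16", "0.4.16"))               # True: same version
print(within_window("0.7.0", "0.4.16"))                # False: 3 minors ahead
print(within_window("0.4.20", "0.4.16", component=2))  # False: 4 releases ahead
```

Applying the window to the third field instead of the second makes the rule much stricter, which matters when most releases only change the last number.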

@hawkinsp
Contributor

hawkinsp commented Oct 20, 2023

Upstream JAX maintainer here: there's probably no harm in pinning the most recent jaxlib as a dependency of each jax release. We may soon do that for the pip packages as well.

The main reason we haven't done so yet is that there are multiple kinds of jaxlib in the pip release: CUDA and non-CUDA. There's no good way to express that with a constraint. However, we're planning to split the upstream jaxlib package into jaxlib, which will become a hard dependency of jax, and packages named something like jax-cuda11-plugin and jax-cuda12-plugin, which will contain everything CUDA-specific. I don't know whether you want to follow suit with the conda packaging; that's up to you.

However, since you already have a method for handling CUDA/non-CUDA variants in the conda-forge build, I imagine you could just add a hard constraint right now.

@beckermr
Member Author

Thank you!

@ngam
Contributor

ngam commented Oct 24, 2023

We usually track the pin in the jax feedstock. A PR was just merged that will hopefully catch problems going forward. See conda-forge/jax-feedstock#130.

@beckermr, let me know if that PR addresses your concerns or if we should come up with a better solution.

@beckermr
Member Author

That's great, though it requires manual updates, which sucks.

@ngam
Contributor

ngam commented Oct 24, 2023

But it should (in theory) fail if not updated, so it's easier to spot... maybe?

@beckermr
Member Author

I think we could put jaxlib in host and then pin the run requirement to be at or above that version. That would work like a compiler, where you always have to be at or after the version used for the build.
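The host-based idea boils down to "the jaxlib seen at build time becomes the run-time floor." In a recipe this is usually done with conda-build's pin_compatible Jinja function; the sketch below only models the resulting rule in plain Python. The helper names are hypothetical and it assumes numeric dotted versions.

```python
from __future__ import annotations


def parse_version(v: str) -> tuple[int, ...]:
    """Parse a purely numeric dotted version (assumption: no rc/dev suffixes)."""
    return tuple(int(part) for part in v.split("."))


def run_requirement_from_host(host_jaxlib: str) -> str:
    """Hypothetical: the jaxlib resolved in the host env becomes a lower bound only."""
    return f"jaxlib >={host_jaxlib}"


def satisfies_floor(installed: str, floor: str) -> bool:
    """Like a compiler runtime: any version at or after the build-time one is accepted."""
    return parse_version(installed) >= parse_version(floor)


host = "0.4.16"  # illustrative build-time version, not taken from the feedstock
print(run_requirement_from_host(host))  # jaxlib >=0.4.16
print(satisfies_floor("0.4.18", host))  # True
print(satisfies_floor("0.4.9", host))   # False
```

Note the difference from a plain >={{ version }} constraint: the floor tracks whatever jaxlib was actually present at build time, so it moves forward automatically without manual edits.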

If we see errors, maybe we revisit.

Labels: none yet · Projects: none yet · Development: no branches or pull requests · 5 participants