
Add custom_jvp to jax.numpy.ldexp. #23923

Merged
merged 1 commit into from
Sep 30, 2024

Conversation

carlosgmartin
Contributor

Addresses #11467 (comment).

@jakevdp
Collaborator

jakevdp commented Sep 25, 2024

Thanks for looking into this! I'm not sure this is actually the correct gradient behavior for this function. As I mentioned in #11467 (comment), ldexp is a bit of a strange function: it nominally computes x * 2 ** y, but it doesn't really compute that expression, because it operates on the details of the bitwise representation of x and y. Because of that, I don't think autodiff really applies to this function: autodiff operates only in the space of real numbers represented by the floating-point implementation, and bitwise manipulations don't have a gradient in that space.

Maybe the best solution here would be to define custom_jvp that raises an error saying basically "this function isn't differentiable". What do you think?
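A minimal sketch of this suggestion (illustrative only, not the actual JAX implementation) would attach a custom JVP rule whose only job is to raise an informative error when differentiation is attempted; the forward computation is unchanged:

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch: wrap the existing jnp.ldexp and make any attempt
# to differentiate through it raise an error.
@jax.custom_jvp
def ldexp(x, n):
    return jnp.ldexp(x, n)

@ldexp.defjvp
def _ldexp_jvp(primals, tangents):
    raise ValueError("ldexp operates on bitwise representations and is not differentiable.")
```

With this sketch, `ldexp(2.0, 3)` still evaluates to 16.0, but `jax.grad(lambda x: ldexp(x, 3))(1.0)` raises the `ValueError` during tracing.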

@jakevdp jakevdp self-assigned this Sep 25, 2024
@carlosgmartin
Contributor Author

carlosgmartin commented Sep 25, 2024

@jakevdp Thanks for your feedback.

It seems to me that if a JAX function extensionally computes a differentiable mathematical function $f$, then, regardless of how it is implemented internally, it ought to have the gradient of $f$.

Indeed, my understanding is that one of the primary use cases of custom_jvp is precisely to recover gradients of functions whose internal implementations are not auto-differentiable.

Since ldexp computes $f(x) = x \cdot 2^n$, the least surprising behavior from a user's point of view would be that $f'(x) = 2^n$. Indeed, a user (or the compiler) might encounter such an expression in a program and substitute it with ldexp as an optimization. It would seem surprising if this substitution were to cause an error on a backward pass.
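This proposed gradient can be sketched with a custom JVP rule (again illustrative, not the merged implementation): since $f$ is linear in $x$, the tangent is just the input tangent scaled by $2^n$, with the integer exponent treated as non-differentiable:

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch: give ldexp the gradient of f(x) = x * 2**n,
# i.e. df/dx = 2**n, while keeping the existing forward computation.
@jax.custom_jvp
def ldexp(x, n):
    return jnp.ldexp(x, n)

@ldexp.defjvp
def _ldexp_jvp(primals, tangents):
    x, n = primals
    x_dot, _ = tangents  # the integer exponent has no tangent
    primal_out = jnp.ldexp(x, n)
    tangent_out = x_dot * 2.0 ** n  # linear in x_dot, so reverse mode works too
    return primal_out, tangent_out
```

Under this sketch, `jax.grad(lambda x: ldexp(x, 3))(1.0)` evaluates to 8.0, matching $f'(x) = 2^3$.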

Thoughts?

@jakevdp
Collaborator

jakevdp commented Sep 25, 2024

So my question is, if we're just computing $x * 2 ^ y$, why don't we fix this by changing the implementation to

def ldexp(x, y):
  return x * 2 ** y

Then no custom JVP is necessary at all.
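To illustrate the point: with the plain-arithmetic version, the expected gradient falls out of autodiff with no custom rule at all (a minimal sketch, not the library implementation):

```python
import jax

def ldexp(x, n):
    # Plain-arithmetic version: differentiable by construction.
    return x * 2.0 ** n

# d/dx (x * 2**4) = 16
grad_x = jax.grad(ldexp)(3.0, 4)
```

Here `ldexp(3.0, 4)` is 48.0 and `grad_x` is 16.0, with no `custom_jvp` involved.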

@carlosgmartin
Contributor Author

carlosgmartin commented Sep 26, 2024

I think (the bit-twiddling implementation of) ldexp is supposed to be a faster way to do that, at least on some platforms.

The CUDA Math API has dedicated ldexp functions for single and double precision.

The C standard library also has dedicated ldexp functions.

Not sure what the performance advantage is for different platforms.
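For reference, the C-style semantics being discussed can be checked against Python's standard-library math.ldexp, which agrees exactly with x * 2**n for in-range values (power-of-two scaling only adjusts the floating-point exponent):

```python
import math

# math.ldexp(x, n) computes x * 2**n exactly, barring overflow/underflow.
for x in (0.75, -1.5, 3.14159):
    for n in (-3, 0, 10):
        assert math.ldexp(x, n) == x * 2.0 ** n
```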

@jakevdp
Collaborator

jakevdp commented Sep 26, 2024

Sure, but JAX does not dispatch to any of those fast kernels, and I imagine the current bit-twiddling implementation is far slower than just writing x1 * 2 ** x2. Is there any good reason to keep the current implementation?

@jakevdp
Collaborator

jakevdp commented Sep 26, 2024

I guess stepping back, here are the options:

  1. ldexp is fundamentally a bit-twiddling operation. In this case, its autodiff behavior is poorly defined, and if we add a custom_jvp, it should probably just raise an error.
  2. ldexp represents the mathematical operation $x * 2^y$ with platform-specific implementation details when custom kernels are available. In this case, the optimal implementation in JAX would be to write x * 2 ** y, which would have the side effect of making custom_jvp unnecessary.

Until now, we've approached this as (1). Which do you think is the right approach?

@carlosgmartin
Contributor Author

You raise an interesting point. If there's indeed no performance advantage to the bit-twiddling implementation of ldexp(x, n) over x * 2 ** n, then perhaps the former isn't necessary and can be replaced with the latter.

At least for the time being, perhaps it's worth adding a note to the documentation for ldexp stating that there's no performance advantage to its current bit-twiddling implementation over x * 2 ** n.

I'd also welcome any additional opinions from people who are more familiar with the hardware side of things.

[Resolved review comments on jax/_src/numpy/ufuncs.py]
@jakevdp jakevdp left a comment

Looks good!

One other thing: we should update the function docs with info about the implementation (this would involve removing the @implements decorator and writing a full docstring).

If you'd like to do this as part of the PR then go ahead, but I'm happy to update docs in a followup.

@carlosgmartin
Contributor Author

I'll let you handle that so you can choose the best wording.

@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Sep 30, 2024
@copybara-service copybara-service bot merged commit 31cb3fd into jax-ml:main Sep 30, 2024
11 of 12 checks passed
@carlosgmartin carlosgmartin deleted the ldexp_custom_jvp branch September 30, 2024 23:23
@jakevdp
Collaborator

jakevdp commented Sep 30, 2024

Thanks for putting this together!
