Hangs encountered in IDAKLUJax unit tests (test_jacrev_vmap and others) #3948

Closed
agriyakhetarpal opened this issue Mar 31, 2024 · 10 comments
Labels
bug Something isn't working priority: medium To be resolved if time allows

Comments

@agriyakhetarpal
Member

PyBaMM Version

develop

Python Version

3.11.8

Describe the bug

The test_jacrev_vmap test case in the TestIDAKLUJax class (in tests/unit/test_solvers/test_idaklu_jax.py) frequently hangs during local development. It is one of the slowest tests to pass: coverage logging appears to get stuck almost indefinitely at 99%, and this test in particular takes several orders of magnitude longer to complete than the rest of the tests.

This most likely stems from the recent migration to pytest for running the unit tests (#3857), which also brought support for parallel execution of unit tests via pytest-xdist; the JAX-related unit tests take up a lot of time in CI in parallel mode.

Here's an SVG from a profiling sample with the pytest-profiling plugin from @prady0t earlier in the #infrastructure channel on Slack:

[Figure: combined profiling SVG from pytest-profiling; image not available]

which reveals that something is up with the JAX-related tests.

Steps to Reproduce

There isn't a better reproducer at this time, but one can run nox -s coverage or its pytest --cov equivalent in the root directory. It is a bit slower than nox -s unit, but both seem to have the same issue.

Relevant log output

No response

@agriyakhetarpal agriyakhetarpal added bug Something isn't working priority: medium To be resolved if time allows labels Mar 31, 2024
@agriyakhetarpal
Member Author

Temporary solution: decorate the affected test classes and methods with @pytest.mark.xdist_group(name="serial execution") so that they run serially inside the same worker.

This fixes test_jacrev_vmap, which now runs normally. However, test_jacrev_vector_getvars, plus test_solver_ and test_solver_sensitivities from the JAX BDF solver, are three other culprits that still take a lot longer than any other test. I assume we can extract more speedups from the newly enabled parallel testing, but these test cases do not budge and are causing bottlenecks. Maybe @jsbrittain would have some suggestions here?
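The temporary workaround described above can be sketched roughly as follows. The class name TestIDAKLUJaxSketch and the placeholder test bodies are hypothetical and stand in for the real tests; only the xdist_group marker and its name argument come from the comment above.

```python
import pytest


# Hypothetical stand-in for the real TestIDAKLUJax class.
# Marking the class applies the marker to every test method in it,
# so all of them share the same xdist group.
@pytest.mark.xdist_group(name="serial execution")
class TestIDAKLUJaxSketch:
    def test_jacrev_vmap(self):
        assert True  # placeholder body, not the real test

    def test_jacrev_vector_getvars(self):
        assert True  # placeholder body, not the real test
```

Note that pytest-xdist only honours xdist_group when tests are distributed with --dist loadgroup; under the other distribution modes the marker has no effect on scheduling.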

@agriyakhetarpal
Member Author

To give a rough sense of the issue at hand, here's what I see locally on an M-series macOS machine:

Running the entire coverage suite with the JAX tests included:

1596 passed, 2 skipped in 663.93s (0:11:03)

Running it again without test_idaklu_jax.py and test_jax_bdf_solver.py brings 98% of the test suite to completion in just 87 seconds!
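As a back-of-the-envelope check on the two timings above, the difference can be attributed to the two excluded files. This is only a rough estimate: the run is parallel, so wall-clock time does not add up cleanly, and the variable names are made up for illustration.

```python
# Wall-clock figures quoted above, in seconds.
full_suite_s = 663.93        # all tests, JAX test files included
without_jax_files_s = 87.0   # same run minus the two JAX test files

# Rough estimate of the time attributable to the two excluded files.
jax_files_s = full_suite_s - without_jax_files_s
jax_share = jax_files_s / full_suite_s

print(f"{jax_files_s:.2f} s ({jax_share:.0%} of the full run)")
```

On these numbers, the two files account for roughly 577 of the 664 seconds.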

@agriyakhetarpal agriyakhetarpal changed the title Hangs encountered in IDAKLUJax unit tests (test_jacrev_vmap) Hangs encountered in IDAKLUJax unit tests (test_jacrev_vmap and others) Mar 31, 2024
@kratman
Contributor

kratman commented Apr 3, 2024

I saw this on Python 3.9 as well.

@kratman
Contributor

kratman commented Apr 3, 2024

@agriyakhetarpal To me it looks like the functions in there are themselves parallel, so running parallel tests on parallel code means a ton of extra threads. On my Mac, those tests seem to use 4 threads each.
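The oversubscription concern above can be made concrete with a small calculation. This is a sketch under stated assumptions: one test worker per CPU core is assumed for illustration, the 4 threads per test is the figure reported above, and the function name is hypothetical.

```python
import os


def oversubscription(workers: int, threads_per_test: int, cores: int) -> float:
    """Ratio of runnable threads to CPU cores when every worker runs a threaded test."""
    return workers * threads_per_test / cores


cores = os.cpu_count() or 1
workers = cores        # assumption: one parallel test worker per core
threads_per_test = 4   # figure reported in the comment above

ratio = oversubscription(workers, threads_per_test, cores)
print(f"{workers * threads_per_test} threads on {cores} cores ({ratio:.1f}x)")
```

With one worker per core and 4 threads per test, the machine is asked to run four times as many threads as it has cores whenever all workers hit threaded tests at once.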

@cringeyburger
Contributor

I want to help with this issue. Please let me know if you need another hand.

@agriyakhetarpal
Member Author

Thanks, @prady0t. So far it's really just the attempt in #3948 (comment) that has helped, and only with one of the four tests, so we need to dig deeper: running the tests in parallel slows them down, but they are also really slow in serial execution. I think this can be tackled at a later time since the tests pass, at least. I think you're working on #3940 with @lorenzofavaro, and that is a higher-priority issue we need to tackle at this moment :)

@agriyakhetarpal
Member Author

Oops, I just realised that I tagged the wrong person; I am sorry. By all means, please feel free to help out here, @cringeyburger!

@agriyakhetarpal
Member Author

This was discussed in today's GSoC meeting for the pytest migration project, and I no longer see this problem after the recent update to jaxlib==0.4.27:

======================================================================= short test summary info ========================================================================
SKIPPED [1] tests/unit/test_expression_tree/test_operations/test_latexify.py:84: Only run for Linux
SKIPPED [1] tests/unit/test_solvers/test_idaklu_jax.py:91: Both IDAKLU and JAX are available
=================================================================== 1608 passed, 2 skipped in 56.63s ===================================================================
nox > Session unit was successful.

The trio of test_jacrev_vmap, test_jacrev_vector_getvars, and test_solver_sensitivities now runs without any hangs; none of them seems to cause any trouble on my macOS machine. We should be in a position to close this for now. If someone else can confirm this by running a fresh nox -s unit, that would be great.

@kratman
Contributor

kratman commented Jun 12, 2024

I don't think I have seen it recently either

@agriyakhetarpal
Member Author

Perfect, closing
