-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Making TorchForce CUDA-graph aware (#103)
* Add CUDA graph draft * Initialize energy and force tensors in the GPU. * Add comment on graph capture * Catch torch exception if the model fails to capture. * Replay graph just after construction Finish capturing before rethrowing if an exception occurred during capture * Add python-side test script for CUDA graphs * Implement properties * Update the Python bindings * Unify the API for properties * Pass the propery map to the constructor * Skip graph tests if no GPU is present * Guard CUDA graph behavior with the CUDA_GRAPH_ENABLE macro * Check validity of the useCUDAGraphs property * Add missing bracket to openmmtorch.i * Fix bug in useCUDAgraph selection * Update tests * Add test for get/setProperty * Update documentation with new functionality * Add a CUDA graph test for a model that returns only energy * Add contributors * Reset pos grads after graph capture. Make energy and force tensors persistent. * Add tests that execute the model many times to catch bugs related with CUDA graph capture * Run formatter * Warmup model for several steps * Include gradient reset into the graph * Do not reset energy and force tensors before graph capture * Remove unnecessary line * Add tests for larger number of particles * Remove unnecessary compilation guard now that Pytorch 1.10 is not supported * Simplify getTensorPointer now that Pytorch 1.7 is not supported * Change addForcesToOpenMM to addForces * Change execute_graph to executeGraph * Wrap graph warming up in a try/catch block * Add correctness test for modules that only provide energy * Revert "Add correctness test for modules that only provide energy" This reverts commit d20f4bf. * Explicit conversion to correct type in getTensorPointer * Added a new property for TorchForce, CUDAGraphWarmupSteps. * Clarify docs * Document properties * Throw if requested property does not exist * Change getProperty(string) to getProperties() * Add getProperties to python wrappers * Fix formatting * Set default properties * Update tests * Update some comments --------- Co-authored-by: Raimondas Galvelis <r.galvelis@acellera.com>
- Loading branch information
1 parent
f8055c3
commit b76deb4
Showing
8 changed files
with
340 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.