Making TorchForce CUDA-graph aware (#103)
* Add CUDA graph draft

* Initialize energy and force tensors in the GPU.

* Add comment on graph capture

* Catch torch exception if the model fails to capture.

* Replay graph just after construction
Finish capturing before rethrowing if an exception occurred during capture

* Add python-side test script for CUDA graphs

* Implement properties

* Update the Python bindings

* Unify the API for properties

* Pass the property map to the constructor

* Skip graph tests if no GPU is present

* Guard CUDA graph behavior with the CUDA_GRAPH_ENABLE macro

* Check validity of the useCUDAGraphs property

* Add missing bracket to openmmtorch.i

* Fix bug in useCUDAgraph selection

* Update tests

* Add test for get/setProperty

* Update documentation with new functionality

* Add a CUDA graph test for a model that returns only energy

* Add contributors

* Reset pos grads after graph capture. Make energy and force tensors persistent.

* Add tests that execute the model many times to catch bugs related to
CUDA graph capture

* Run formatter

* Warmup model for several steps

* Include gradient reset into the graph

* Do not reset energy and force tensors before graph capture

* Remove unnecessary line

* Add tests for larger number of particles

* Remove unnecessary compilation guard now that PyTorch 1.10 is not supported

* Simplify getTensorPointer now that PyTorch 1.7 is not supported

* Change addForcesToOpenMM to addForces

* Change execute_graph to executeGraph

* Wrap graph warming up in a try/catch block

* Add correctness test for modules that only provide energy

* Revert "Add correctness test for modules that only provide energy"

This reverts commit d20f4bf.

* Explicit conversion to correct type in getTensorPointer

* Added a new property for TorchForce, CUDAGraphWarmupSteps.

* Clarify docs

* Document properties

* Throw if requested property does not exist

* Change getProperty(string) to getProperties()

* Add getProperties to python wrappers

* Fix formatting

* Set default properties

* Update tests

* Update some comments

---------

Co-authored-by: Raimondas Galvelis <r.galvelis@acellera.com>
RaulPPelaez and Raimondas Galvelis committed Apr 21, 2023
1 parent f8055c3 commit b76deb4
Showing 8 changed files with 340 additions and 86 deletions.
38 changes: 38 additions & 0 deletions README.md
@@ -252,6 +252,44 @@ to return forces.
torch_force.setOutputsForces(True)
```

Recording the model into a CUDA graph
-------------------------------------

You can ask `TorchForce` to run the model using [CUDA graphs](https://pytorch.org/docs/stable/notes/cuda.html#cuda-graphs). Not every model is compatible with this feature, but for those that are it can give a significant performance boost. To enable it, use the CUDA platform and set a special property on `TorchForce`:

```python
torch_force.setProperty("useCUDAGraphs", "true")
# The property can also be set at construction
torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'})
```

The first time the model is run, it is recorded (captured) into a CUDA graph. Subsequent runs replay the recorded graph, which can be significantly faster. Recording can fail, in which case an `OpenMMException` is raised. If that happens, disable CUDA graphs and try again.

The model must be run at least once before recording, a step known as warmup.
By default `TorchForce` runs the model just once before recording, but a longer warmup may be desirable. In that case, set the property `CUDAGraphWarmupSteps`:
```python
torch_force.setProperty("CUDAGraphWarmupSteps", "12")
```
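
The order of operations described above (warm up, record once, then replay) can be pictured with a toy model in plain Python. This is only a sketch of the control flow under the assumption that recording happens on the first call and every call ends with a replay; it involves no CUDA or PyTorch, and `GraphRunner`, `warmup_steps`, and `log` are hypothetical names that are not part of `TorchForce`.

```python
# Toy model of the warm-up / record / replay sequence; no CUDA or PyTorch involved.
# GraphRunner and its attributes are hypothetical names used only for illustration.
class GraphRunner:
    def __init__(self, model, warmup_steps=1):
        self.model = model
        self.warmup_steps = warmup_steps
        self.recorded = None  # stands in for the recorded CUDA graph
        self.log = []         # which phase handled each model invocation

    def __call__(self, x):
        if self.recorded is None:
            # Warm-up: run the model eagerly before recording.
            for _ in range(self.warmup_steps):
                self.model(x)
                self.log.append("warmup")
            # Recording happens once, on the first call only.
            self.model(x)
            self.log.append("record")
            self.recorded = self.model
        # Every call, including the first, ends with a replay.
        self.log.append("replay")
        return self.recorded(x)


runner = GraphRunner(lambda x: 2 * x, warmup_steps=3)
first = runner(1.0)
second = runner(2.0)
print(runner.log)
```

With `warmup_steps=3`, only the first call pays for the warm-up runs and the recording; the second call is a replay alone.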

List of available properties
----------------------------

Some `TorchForce` functionality can be customized by setting properties on an instance. Properties can be set at construction or with `setProperty`. A property is a key/value pair of strings. For instance:

```python
torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'})
#Alternatively setProperty can be used to configure an already created instance.
#torch_force.setProperty("useCUDAGraphs", "true")
print("Current properties:")
for property in torch_force.getProperties():
    print(property.key, property.value)
```

Currently, the following properties are available:

1. `useCUDAGraphs`: Turns on the CUDA graph functionality. Defaults to `"false"`.
2. `CUDAGraphWarmupSteps`: When CUDA graphs are being used, controls the number of warmup calls to the model before recording. Defaults to `"1"`.
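
The rules that `TorchForce.cpp` in this commit applies to properties (start from the defaults, overwrite with user values, reject unknown names) can be sketched in plain Python. `PropertyStore` and its methods are hypothetical stand-ins, not the `TorchForce` API, and `KeyError` takes the place of the `OpenMMException` that the C++ code throws.

```python
# Sketch of the property rules; PropertyStore is a hypothetical stand-in,
# not the TorchForce API. KeyError stands in for OpenMMException.
DEFAULT_PROPERTIES = {"useCUDAGraphs": "false", "CUDAGraphWarmupSteps": "1"}


class PropertyStore:
    def __init__(self, properties=None):
        # Start from the defaults, then apply user-supplied values.
        self._properties = dict(DEFAULT_PROPERTIES)
        for name, value in (properties or {}).items():
            self.set_property(name, value)

    def set_property(self, name, value):
        # Only known property names are accepted.
        if name not in self._properties:
            raise KeyError(f"Unknown property '{name}'")
        self._properties[name] = value

    def get_properties(self):
        # Return a copy so callers cannot bypass the name check.
        return dict(self._properties)


store = PropertyStore({"useCUDAGraphs": "true"})
print(store.get_properties())

try:
    store.set_property("useCudaGraphs", "true")  # wrong capitalization
except KeyError as err:
    print("rejected:", err)
```

Because names are checked against the defaults, a misspelled key fails loudly instead of being silently ignored.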

License
=======

26 changes: 22 additions & 4 deletions openmmapi/include/TorchForce.h
@@ -11,7 +11,7 @@
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
-* Contributors: *
+* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
@@ -34,6 +34,7 @@

#include "openmm/Context.h"
#include "openmm/Force.h"
+#include <map>
#include <string>
#include <torch/torch.h>
#include "internal/windowsExportTorch.h"
Expand All @@ -54,17 +55,20 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* Create a TorchForce. The network is defined by a PyTorch ScriptModule saved
* to a file.
*
-* @param file the path to the file containing the network
+* @param file the path to the file containing the network
+* @param properties optional map of properties
*/
-TorchForce(const std::string& file);
+TorchForce(const std::string& file,
+const std::map<std::string, std::string>& properties = {});
/**
* Create a TorchForce. The network is defined by a PyTorch ScriptModule
* Note that this constructor makes a copy of the provided module.
* Any changes to the module after calling this constructor will be ignored by TorchForce.
*
* @param module an instance of the torch module
+* @param properties optional map of properties
*/
-TorchForce(const torch::jit::Module &module);
+TorchForce(const torch::jit::Module &module, const std::map<std::string, std::string>& properties = {});
/**
* Get the path to the file containing the network.
* If the TorchForce instance was constructed with a module, instead of a filename,
@@ -140,6 +144,18 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param defaultValue the default value of the parameter
*/
void setGlobalParameterDefaultValue(int index, double defaultValue);
+/**
+* Set a value of a property.
+*
+* @param name the name of the property
+* @param value the value of the property
+*/
+void setProperty(const std::string& name, const std::string& value);
+/**
+* Get the map of properties for this instance.
+* @return A map of property names to values.
+*/
+const std::map<std::string, std::string>& getProperties() const;
protected:
OpenMM::ForceImpl* createImpl() const;
private:
@@ -148,6 +164,8 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
torch::jit::Module module;
+std::map<std::string, std::string> properties;
+std::string emptyProperty;
};

/**
25 changes: 21 additions & 4 deletions openmmapi/src/TorchForce.cpp
@@ -8,7 +8,7 @@
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
-* Contributors: *
+* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
@@ -41,10 +41,17 @@ using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

-TorchForce::TorchForce(const torch::jit::Module& module) : file(), usePeriodic(false), outputsForces(false), module(module) {
+TorchForce::TorchForce(const torch::jit::Module& module, const map<string, string>& properties) : file(), usePeriodic(false), outputsForces(false), module(module) {
+const std::map<std::string, std::string> defaultProperties = {{"useCUDAGraphs", "false"}, {"CUDAGraphWarmupSteps", "1"}};
+this->properties = defaultProperties;
+for (auto& property : properties) {
+if (defaultProperties.find(property.first) == defaultProperties.end())
+throw OpenMMException("TorchForce: Unknown property '" + property.first + "'");
+this->properties[property.first] = property.second;
+}
}

-TorchForce::TorchForce(const std::string& file) : TorchForce(torch::jit::load(file)) {
+TorchForce::TorchForce(const std::string& file, const map<string, string>& properties) : TorchForce(torch::jit::load(file), properties) {
this->file = file;
}

@@ -78,7 +85,7 @@ bool TorchForce::getOutputsForces() const {

int TorchForce::addGlobalParameter(const string& name, double defaultValue) {
globalParameters.push_back(GlobalParameterInfo(name, defaultValue));
-return globalParameters.size()-1;
+return globalParameters.size() - 1;
}

int TorchForce::getNumGlobalParameters() const {
@@ -104,3 +111,13 @@ void TorchForce::setGlobalParameterDefaultValue(int index, double defaultValue)
ASSERT_VALID_INDEX(index, globalParameters);
globalParameters[index].defaultValue = defaultValue;
}

+void TorchForce::setProperty(const std::string& name, const std::string& value) {
+if (properties.find(name) == properties.end())
+throw OpenMMException("TorchForce: Unknown property '" + name + "'");
+properties[name] = value;
+}
+
+const std::map<std::string, std::string>& TorchForce::getProperties() const {
+return properties;
+}