From b76deb4df8661713925a6254613823df8601c9dc Mon Sep 17 00:00:00 2001 From: Raul Date: Fri, 21 Apr 2023 13:52:39 +0100 Subject: [PATCH] Making TorchForce CUDA-graph aware (#103) * Add CUDA graph draft * Initialize energy and force tensors in the GPU. * Add comment on graph capture * Catch torch exception if the model fails to capture. * Replay graph just after construction Finish capturing before rethrowing if an exception occurred during capture * Add python-side test script for CUDA graphs * Implement properties * Update the Python bindings * Unify the API for properties * Pass the propery map to the constructor * Skip graph tests if no GPU is present * Guard CUDA graph behavior with the CUDA_GRAPH_ENABLE macro * Check validity of the useCUDAGraphs property * Add missing bracket to openmmtorch.i * Fix bug in useCUDAgraph selection * Update tests * Add test for get/setProperty * Update documentation with new functionality * Add a CUDA graph test for a model that returns only energy * Add contributors * Reset pos grads after graph capture. Make energy and force tensors persistent. * Add tests that execute the model many times to catch bugs related with CUDA graph capture * Run formatter * Warmup model for several steps * Include gradient reset into the graph * Do not reset energy and force tensors before graph capture * Remove unnecessary line * Add tests for larger number of particles * Remove unnecessary compilation guard now that Pytorch 1.10 is not supported * Simplify getTensorPointer now that Pytorch 1.7 is not supported * Change addForcesToOpenMM to addForces * Change execute_graph to executeGraph * Wrap graph warming up in a try/catch block * Add correctness test for modules that only provide energy * Revert "Add correctness test for modules that only provide energy" This reverts commit d20f4bfa83831e72a0f6a0f9baf650ac15f174f4. * Explicit conversion to correct type in getTensorPointer * Added a new property for TorchForce, CUDAGraphWarmupSteps. * Clarify docs * Document properties * Throw if requested property does not exist * Change getProperty(string) to getProperties() * Add getProperties to python wrappers * Fix formatting * Set default properties * Update tests * Update some comments --------- Co-authored-by: Raimondas Galvelis --- README.md | 38 +++++ openmmapi/include/TorchForce.h | 26 ++- openmmapi/src/TorchForce.cpp | 25 ++- platforms/cuda/src/CudaTorchKernels.cpp | 203 +++++++++++++++--------- platforms/cuda/src/CudaTorchKernels.h | 14 +- python/openmmtorch.i | 11 +- python/tests/TestCUDAGraphs.py | 93 +++++++++++ python/tests/TestTorchForce.py | 16 +- 8 files changed, 340 insertions(+), 86 deletions(-) create mode 100644 python/tests/TestCUDAGraphs.py diff --git a/README.md b/README.md index 08e752c0..e9ebe29a 100644 --- a/README.md +++ b/README.md @@ -252,6 +252,44 @@ to return forces. torch_force.setOutputsForces(True) ``` +Recording the model into a CUDA graph +------------------------------------- + +You can ask `TorchForce` to run the model using [CUDA graphs](https://pytorch.org/docs/stable/notes/cuda.html#cuda-graphs). Not every model will be compatible with this feature, but it can be a significant performance boost for some models. To enable it the CUDA platform must be used and an special property must be provided to `TorchForce`: + +```python +torch_force.setProperty("useCUDAGraphs", "true") +# The property can also be set at construction +torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'}) +``` + +The first time the model is run, it will be compiled (also known as recording) into a CUDA graph. Subsequent runs will use the compiled graph, which can be significantly faster. It is possible that compilation fails, in which case an `OpenMMException` will be raised. If that happens, you can disable CUDA graphs and try again. + +It is required to run the model at least once before recording, in what is known as warmup. +By default ```TorchForce``` will run the model just once before recording, but longer warming up might be desired. In these cases one can set the property ```CUDAGraphWarmupSteps```: +```python +torch_force.setProperty("CUDAGraphWarmupSteps", "12") +``` + +List of available properties +---------------------------- + +Some ```TorchForce``` functionalities can be customized by setting properties on an instance of it. Properties can be set at construction or by using ```setProperty```. A property is a pair of key/value strings. For instance: + +```python +torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'}) +#Alternatively setProperty can be used to configure an already created instance. +#torch_force.setProperty("useCUDAGraphs", "true") +print("Current properties:") +for property in torch_force.getProperties(): + print(property.key, property.value) +``` + +Currently, the following properties are available: + +1. useCUDAGraphs: Turns on the CUDA graph functionality +2. CUDAGraphWarmupSteps: When CUDA graphs are being used, controls the number of warmup calls to the model before recording. + License ======= diff --git a/openmmapi/include/TorchForce.h b/openmmapi/include/TorchForce.h index 46e28193..4406eb20 100644 --- a/openmmapi/include/TorchForce.h +++ b/openmmapi/include/TorchForce.h @@ -11,7 +11,7 @@ * * * Portions copyright (c) 2018-2022 Stanford University and the Authors. * * Authors: Peter Eastman * - * Contributors: * + * Contributors: Raimondas Galvelis, Raul P. Pelaez * * * * Permission is hereby granted, free of charge, to any person obtaining a * * copy of this software and associated documentation files (the "Software"), * @@ -34,6 +34,7 @@ #include "openmm/Context.h" #include "openmm/Force.h" +#include #include #include #include "internal/windowsExportTorch.h" @@ -54,17 +55,20 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * Create a TorchForce. The network is defined by a PyTorch ScriptModule saved * to a file. * - * @param file the path to the file containing the network + * @param file the path to the file containing the network + * @param properties optional map of properties */ - TorchForce(const std::string& file); + TorchForce(const std::string& file, + const std::map& properties = {}); /** * Create a TorchForce. The network is defined by a PyTorch ScriptModule * Note that this constructor makes a copy of the provided module. * Any changes to the module after calling this constructor will be ignored by TorchForce. * * @param module an instance of the torch module + * @param properties optional map of properties */ - TorchForce(const torch::jit::Module &module); + TorchForce(const torch::jit::Module &module, const std::map& properties = {}); /** * Get the path to the file containing the network. * If the TorchForce instance was constructed with a module, instead of a filename, @@ -140,6 +144,18 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @param defaultValue the default value of the parameter */ void setGlobalParameterDefaultValue(int index, double defaultValue); + /** + * Set a value of a property. + * + * @param name the name of the property + * @param value the value of the property + */ + void setProperty(const std::string& name, const std::string& value); + /** + * Get the map of properties for this instance. + * @return A map of property names to values. + */ + const std::map& getProperties() const; protected: OpenMM::ForceImpl* createImpl() const; private: @@ -148,6 +164,8 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { bool usePeriodic, outputsForces; std::vector globalParameters; torch::jit::Module module; + std::map properties; + std::string emptyProperty; }; /** diff --git a/openmmapi/src/TorchForce.cpp b/openmmapi/src/TorchForce.cpp index b9aa64d6..41aa1fb2 100644 --- a/openmmapi/src/TorchForce.cpp +++ b/openmmapi/src/TorchForce.cpp @@ -8,7 +8,7 @@ * * * Portions copyright (c) 2018-2022 Stanford University and the Authors. * * Authors: Peter Eastman * - * Contributors: * + * Contributors: Raimondas Galvelis, Raul P. Pelaez * * * * Permission is hereby granted, free of charge, to any person obtaining a * * copy of this software and associated documentation files (the "Software"), * @@ -41,10 +41,17 @@ using namespace TorchPlugin; using namespace OpenMM; using namespace std; -TorchForce::TorchForce(const torch::jit::Module& module) : file(), usePeriodic(false), outputsForces(false), module(module) { +TorchForce::TorchForce(const torch::jit::Module& module, const map& properties) : file(), usePeriodic(false), outputsForces(false), module(module) { + const std::map defaultProperties = {{"useCUDAGraphs", "false"}, {"CUDAGraphWarmupSteps", "1"}}; + this->properties = defaultProperties; + for (auto& property : properties) { + if (defaultProperties.find(property.first) == defaultProperties.end()) + throw OpenMMException("TorchForce: Unknown property '" + property.first + "'"); + this->properties[property.first] = property.second; + } } -TorchForce::TorchForce(const std::string& file) : TorchForce(torch::jit::load(file)) { +TorchForce::TorchForce(const std::string& file, const map& properties) : TorchForce(torch::jit::load(file), properties) { this->file = file; } @@ -78,7 +85,7 @@ bool TorchForce::getOutputsForces() const { int TorchForce::addGlobalParameter(const string& name, double defaultValue) { globalParameters.push_back(GlobalParameterInfo(name, defaultValue)); - return globalParameters.size()-1; + return globalParameters.size() - 1; } int TorchForce::getNumGlobalParameters() const { @@ -104,3 +111,13 @@ void TorchForce::setGlobalParameterDefaultValue(int index, double defaultValue) ASSERT_VALID_INDEX(index, globalParameters); globalParameters[index].defaultValue = defaultValue; } + +void TorchForce::setProperty(const std::string& name, const std::string& value) { + if (properties.find(name) == properties.end()) + throw OpenMMException("TorchForce: Unknown property '" + name + "'"); + properties[name] = value; +} + +const std::map& TorchForce::getProperties() const { + return properties; +} diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 42e23d2c..d2d869ae 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -8,7 +8,7 @@ * * * Portions copyright (c) 2018-2022 Stanford University and the Authors. * * Authors: Peter Eastman * - * Contributors: * + * Contributors: Raimondas Galvelis, Raul P. Pelaez * * * * Permission is hereby granted, free of charge, to any person obtaining a * * copy of this software and associated documentation files (the "Software"), * @@ -35,22 +35,23 @@ #include "openmm/internal/ContextImpl.h" #include #include - +#include +#include using namespace TorchPlugin; using namespace OpenMM; using namespace std; // macro for checking the result of synchronization operation on CUDA // copied from `openmm/platforms/cuda/src/CudaParallelKernels.cpp` -#define CHECK_RESULT(result, prefix) \ -if (result != CUDA_SUCCESS) { \ - std::stringstream m; \ - m<warmupSteps = std::stoi(properties["CUDAGraphWarmupSteps"]); + if (this->warmupSteps <= 0) { + throw OpenMMException("TorchForce: \"CUDAGraphWarmupSteps\" must be a positive integer"); + } + } } -double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { - int numParticles = cu.getNumAtoms(); - - // Push to the PyTorch context - CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context"); - - // Get pointers to the atomic positions and simulation box - void* posData; - void* boxData; +/** + * Get a pointer to the data in a PyTorch tensor. + * The tensor is converted to the correct data type if necessary. + */ +static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) { + void* data; if (cu.getUseDoublePrecision()) { - posData = posTensor.data_ptr(); - boxData = boxTensor.data_ptr(); - } - else { - posData = posTensor.data_ptr(); - boxData = boxTensor.data_ptr(); + data = tensor.to(torch::kFloat64).data_ptr(); + } else { + data = tensor.to(torch::kFloat32).data_ptr(); } + return data; +} +/** + * Prepare the inputs for the PyTorch model, copying positions from the OpenMM context. + */ +std::vector CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context) { + int numParticles = cu.getNumAtoms(); + // Get pointers to the atomic positions and simulation box + void* posData = getTensorPointer(cu, posTensor); + void* boxData = getTensorPointer(cu, boxTensor); // Copy the atomic positions and simulation box to PyTorch tensors { ContextSelector selector(cu); // Switch to the OpenMM context - void* inputArgs[] = {&posData, &boxData, &cu.getPosq().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), - &numParticles, cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer()}; + void* inputArgs[] = {&posData, + &boxData, + &cu.getPosq().getDevicePointer(), + &cu.getAtomIndexArray().getDevicePointer(), + &numParticles, + cu.getPeriodicBoxVecXPointer(), + cu.getPeriodicBoxVecYPointer(), + cu.getPeriodicBoxVecZPointer()}; cu.executeKernel(copyInputsKernel, inputArgs, numParticles); CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context } - // Prepare the input of the PyTorch model vector inputs = {posTensor}; if (usePeriodic) inputs.push_back(boxTensor); for (const string& name : globalNames) inputs.push_back(torch::tensor(context.getParameter(name))); + return inputs; +} + +/** + * Add the computed forces to the total atomic forces. + */ +void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) { + int numParticles = cu.getNumAtoms(); + // Get a pointer to the computed forces + void* forceData = getTensorPointer(cu, forceTensor); + CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the OpenMM context + // Add the computed forces to the total atomic forces + { + ContextSelector selector(cu); // Switch to the OpenMM context + int paddedNumAtoms = cu.getPaddedNumAtoms(); + int forceSign = (outputsForces ? 1 : -1); + void* forceArgs[] = {&forceData, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms, &forceSign}; + cu.executeKernel(addForcesKernel, forceArgs, numParticles); + CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context + } +} - // Execute the PyTorch model - torch::Tensor energyTensor, forceTensor; +/** + * This function launches the workload in a way compatible with CUDA + * graphs as far as OpenMM-Torch goes. Capturing this function when + * the model is not itself graph compatible (due to, for instance, + * implicit synchronizations) will result in a CUDA error. + */ +static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor, + torch::Tensor& forceTensor) { if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); energyTensor = outputs->elements()[0].toTensor(); forceTensor = outputs->elements()[1].toTensor(); - } - else + } else { energyTensor = module.forward(inputs).toTensor(); - - if (includeForces) { - - // Compute force by backprogating the PyTorch model - if (!outputsForces) { + // Compute force by backpropagating the PyTorch model + if (includeForces) { energyTensor.backward(); - forceTensor = posTensor.grad(); - } - - // Get a pointer to the computed forces - void* forceData; - if (cu.getUseDoublePrecision()) { - if (!(forceTensor.dtype() == torch::kFloat64)) // TODO: simplify the logic when support for PyTorch 1.7 is dropped - forceTensor = forceTensor.to(torch::kFloat64); - forceData = forceTensor.data_ptr(); - } - else { - if (!(forceTensor.dtype() == torch::kFloat32)) // TODO: simplify the logic when support for PyTorch 1.7 is dropped - forceTensor = forceTensor.to(torch::kFloat32); - forceData = forceTensor.data_ptr(); + forceTensor = posTensor.grad().clone(); + // Zero the gradient to avoid accumulating it + posTensor.grad().zero_(); } - CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the OpenMM context + } +} - // Add the computed forces to the total atomic forces - { - ContextSelector selector(cu); // Switch to the OpenMM context - int paddedNumAtoms = cu.getPaddedNumAtoms(); - int forceSign = (outputsForces ? 1 : -1); - void* forceArgs[] = {&forceData, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms, &forceSign}; - cu.executeKernel(addForcesKernel, forceArgs, numParticles); - CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context +double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { + // Push to the PyTorch context + CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context"); + auto inputs = prepareTorchInputs(context); + if (!useGraphs) { + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + } else { + const auto stream = c10::cuda::getStreamFromPool(false, posTensor.get_device()); + const c10::cuda::CUDAStreamGuard guard(stream); + // Record graph if not already done + bool is_graph_captured = false; + if (graphs.find(includeForces) == graphs.end()) { + // Warmup the graph workload before capturing. This first + // run before capture sets up allocations so that no + // allocations are needed after. Pytorch's allocator is + // stream capture-aware and, after warmup, will provide + // record static pointers and shapes during capture. + try { + for (int i = 0; i < this->warmupSteps; i++) + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + } + catch (std::exception& e) { + throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what()); + } + graphs[includeForces].capture_begin(); + try { + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + is_graph_captured = true; + graphs[includeForces].capture_end(); + } + catch (std::exception& e) { + if (!is_graph_captured) { + graphs[includeForces].capture_end(); + } + throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what()); + } } - - // Reset the forces - if (!outputsForces) - posTensor.grad().zero_(); + graphs[includeForces].replay(); + } + if (includeForces) { + addForces(forceTensor); } - // Get energy const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context - // Pop to the PyTorch context CUcontext ctx; CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context"); assert(primaryContext == ctx); // Check that the correct context was popped - return energy; } diff --git a/platforms/cuda/src/CudaTorchKernels.h b/platforms/cuda/src/CudaTorchKernels.h index 23a6238c..13f2a9b6 100644 --- a/platforms/cuda/src/CudaTorchKernels.h +++ b/platforms/cuda/src/CudaTorchKernels.h @@ -11,7 +11,7 @@ * * * Portions copyright (c) 2018-2022 Stanford University and the Authors. * * Authors: Peter Eastman * - * Contributors: * + * Contributors: Raimondas Galvelis, Raul P. Pelaez * * * * Permission is hereby granted, free of charge, to any person obtaining a * * copy of this software and associated documentation files (the "Software"), * @@ -34,6 +34,9 @@ #include "TorchKernels.h" #include "openmm/cuda/CudaContext.h" +#include "openmm/cuda/CudaArray.h" +#include +#include namespace TorchPlugin { @@ -46,7 +49,7 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { ~CudaCalcTorchForceKernel(); /** * Initialize the kernel. - * + * * @param system the System this kernel will be applied to * @param force the TorchForce this kernel will be used for * @param module the PyTorch module to use for computing forces and energy @@ -61,15 +64,22 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { * @return the potential energy due to the force */ double execute(OpenMM::ContextImpl& context, bool includeForces, bool includeEnergy); + private: bool hasInitializedKernel; OpenMM::CudaContext& cu; torch::jit::script::Module module; torch::Tensor posTensor, boxTensor; + torch::Tensor energyTensor, forceTensor; std::vector globalNames; bool usePeriodic, outputsForces; CUfunction copyInputsKernel, addForcesKernel; CUcontext primaryContext; + std::map graphs; + std::vector prepareTorchInputs(OpenMM::ContextImpl& context); + bool useGraphs; + void addForces(torch::Tensor& forceTensor); + int warmupSteps; }; } // namespace TorchPlugin diff --git a/python/openmmtorch.i b/python/openmmtorch.i index bfabe86e..94880522 100644 --- a/python/openmmtorch.i +++ b/python/openmmtorch.i @@ -4,6 +4,7 @@ %import(module="simtk.openmm") "swig/OpenMMSwigHeaders.i" %include "swig/typemaps.i" %include +%include %{ #include "TorchForce.h" @@ -46,12 +47,16 @@ $1 = torch::jit::as_module(o).has_value() ? 1 : 0; } +namespace std { + %template(property_map) map; +} + namespace TorchPlugin { class TorchForce : public OpenMM::Force { public: - TorchForce(const std::string& file); - TorchForce(const torch::jit::Module& module); + TorchForce(const std::string& file, const std::map& properties = {}); + TorchForce(const torch::jit::Module& module, const std::map& properties = {}); const std::string& getFile() const; const torch::jit::Module& getModule() const; void setUsesPeriodicBoundaryConditions(bool periodic); @@ -64,6 +69,8 @@ public: void setGlobalParameterName(int index, const std::string& name); double getGlobalParameterDefaultValue(int index) const; void setGlobalParameterDefaultValue(int index, double defaultValue); + void setProperty(const std::string& name, const std::string& value); + const std::map& getProperties() const; /* * Add methods for casting a Force to a TorchForce. diff --git a/python/tests/TestCUDAGraphs.py b/python/tests/TestCUDAGraphs.py new file mode 100644 index 00000000..af1e5c8b --- /dev/null +++ b/python/tests/TestCUDAGraphs.py @@ -0,0 +1,93 @@ +__author__ = "Raul P. Pelaez" +import openmmtorch as ot +import torch +import openmm as mm +import numpy as np +import pytest + + +class UngraphableModule(torch.nn.Module): + def forward(self, positions): + torch.cuda.synchronize() + return (torch.sum(positions**2), -2.0 * positions) + + +class GraphableModule(torch.nn.Module): + def forward(self, positions): + energy = torch.einsum("ij,ij->i", positions, positions).sum() + return (energy, -2.0 * positions) + + +class GraphableModuleOnlyEnergy(torch.nn.Module): + def forward(self, positions): + energy = torch.einsum("ij,ij->i", positions, positions).sum() + return energy + + +def tryToTestForceWithModule( + ModuleType, outputsForce, useGraphs=False, warmup=10, numParticles=10 +): + """Test that the force is correctly computed for a given module type. + Warmup makes OpenMM call TorchForce execution multiple times, which might expose some bugs related to that given that with CUDA graphs the first execution is different from the rest. + """ + module = torch.jit.script(ModuleType()) + torch_force = ot.TorchForce( + module, {"useCUDAGraphs": "true" if useGraphs else "false"} + ) + torch_force.setOutputsForces(outputsForce) + system = mm.System() + positions = np.random.rand(numParticles, 3) + for _ in range(numParticles): + system.addParticle(1.0) + system.addForce(torch_force) + integ = mm.VerletIntegrator(1.0) + platform = mm.Platform.getPlatformByName("CUDA") + context = mm.Context(system, integ, platform) + context.setPositions(positions) + for _ in range(warmup): + state = context.getState(getEnergy=True, getForces=True) + expectedEnergy = np.sum(positions**2) + expectedForce = -2.0 * positions + energy = state.getPotentialEnergy().value_in_unit(mm.unit.kilojoules_per_mole) + force = state.getForces(asNumpy=True).value_in_unit( + mm.unit.kilojoules_per_mole / mm.unit.nanometer + ) + assert np.allclose(expectedEnergy, energy) + assert np.allclose(expectedForce, force) + + +def testUnGraphableModelRaises(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + with pytest.raises(mm.OpenMMException): + tryToTestForceWithModule(UngraphableModule, outputsForce=True, useGraphs=True) + + +@pytest.mark.parametrize("numParticles", [10, 10000]) +@pytest.mark.parametrize("useGraphs", [True, False]) +@pytest.mark.parametrize("warmup", [1, 10]) +def testGraphableModelOnlyEnergyIsCorrect(useGraphs, warmup, numParticles): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + tryToTestForceWithModule( + GraphableModuleOnlyEnergy, + outputsForce=False, + useGraphs=useGraphs, + warmup=warmup, + numParticles=numParticles, + ) + + +@pytest.mark.parametrize("numParticles", [10, 10000]) +@pytest.mark.parametrize("useGraphs", [True, False]) +@pytest.mark.parametrize("warmup", [1, 10]) +def testGraphableModelIsCorrect(useGraphs, warmup, numParticles): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + tryToTestForceWithModule( + GraphableModule, + outputsForce=True, + useGraphs=useGraphs, + warmup=warmup, + numParticles=numParticles, + ) diff --git a/python/tests/TestTorchForce.py b/python/tests/TestTorchForce.py index 332a6cf4..ed358722 100644 --- a/python/tests/TestTorchForce.py +++ b/python/tests/TestTorchForce.py @@ -37,12 +37,13 @@ def testForce(model_file, output_forces, use_module_constructor, use_cv_force, p # Create a force if use_module_constructor: model = pt.jit.load(model_file) - force = ot.TorchForce(model) + force = ot.TorchForce(model, {'useCUDAGraphs': 'false'}) else: - force = ot.TorchForce(model_file) + force = ot.TorchForce(model_file, {'useCUDAGraphs': 'false'}) assert not force.getOutputsForces() # Check the default force.setOutputsForces(output_forces) assert force.getOutputsForces() == output_forces + assert force.getProperties()['useCUDAGraphs'] == 'false' if use_cv_force: # Wrap TorchForce into CustomCVForce cv_force = mm.CustomCVForce('force') @@ -114,4 +115,13 @@ def forward(self, positions): context = mm.Context(system, integrator, platform, properties) context.setPositions(positions) - context.getState(getEnergy=True, getForces=True) \ No newline at end of file + context.getState(getEnergy=True, getForces=True) + + +def testProperties(): + """ Test that the properties are correctly set and retrieved """ + force = ot.TorchForce('../../tests/central.pt') + force.setProperty('useCUDAGraphs', 'true') + assert force.getProperties()['useCUDAGraphs'] == 'true' + force.setProperty('useCUDAGraphs', 'false') + assert force.getProperties()['useCUDAGraphs'] == 'false'