diff --git a/serialization/src/TorchForceProxy.cpp b/serialization/src/TorchForceProxy.cpp index 6d9d312..840eada 100644 --- a/serialization/src/TorchForceProxy.cpp +++ b/serialization/src/TorchForceProxy.cpp @@ -74,7 +74,7 @@ TorchForceProxy::TorchForceProxy() : SerializationProxy("TorchForce") { } void TorchForceProxy::serialize(const void* object, SerializationNode& node) const { - node.setIntProperty("version", 3); + node.setIntProperty("version", 4); const TorchForce& force = *reinterpret_cast(object); node.setStringProperty("file", force.getFile()); try { @@ -95,11 +95,14 @@ void TorchForceProxy::serialize(const void* object, SerializationNode& node) con SerializationNode& paramDerivs = node.createChildNode("ParameterDerivatives"); for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) paramDerivs.createChildNode("Parameter").setStringProperty("name", force.getEnergyParameterDerivativeName(i)); + SerializationNode& properties = node.createChildNode("Properties"); + for (auto& prop : force.getProperties()) + properties.createChildNode("Property").setStringProperty("name", prop.first).setStringProperty("value", prop.second); } void* TorchForceProxy::deserialize(const SerializationNode& node) const { int storedVersion = node.getIntProperty("version"); - if (storedVersion > 3) + if (storedVersion > 4) throw OpenMMException("Unsupported version number"); TorchForce* force; if (storedVersion == 1) { @@ -126,6 +129,9 @@ void* TorchForceProxy::deserialize(const SerializationNode& node) const { if (child.getName() == "ParameterDerivatives") for (auto& parameter : child.getChildren()) force->addEnergyParameterDerivative(parameter.getStringProperty("name")); + if (child.getName() == "Properties") + for (auto& property : child.getChildren()) + force->setProperty(property.getStringProperty("name"), property.getStringProperty("value")); } return force; } diff --git a/serialization/tests/TestSerializeTorchForce.cpp b/serialization/tests/TestSerializeTorchForce.cpp index 410f00a..98753ee 100644 --- a/serialization/tests/TestSerializeTorchForce.cpp +++ b/serialization/tests/TestSerializeTorchForce.cpp @@ -51,6 +51,8 @@ void serializeAndDeserialize(TorchForce force) { force.setUsesPeriodicBoundaryConditions(true); force.setOutputsForces(true); force.addEnergyParameterDerivative("y"); + force.setProperty("useCUDAGraphs", "true"); + force.setProperty("CUDAGraphWarmupSteps", "5"); // Serialize and then deserialize it. @@ -77,6 +79,9 @@ void serializeAndDeserialize(TorchForce force) { ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives()); for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i)); + ASSERT_EQUAL(force.getProperties().size(), force2.getProperties().size()); + for (auto& prop : force.getProperties()) + ASSERT_EQUAL(prop.second, force2.getProperties().at(prop.first)); } void testSerializationFromModule() {