Skip to content

Commit

Permalink
pdlp: export from google3
Browse files Browse the repository at this point in the history
  • Loading branch information
Mizux committed Jul 13, 2023
1 parent a29d1ee commit bad3f17
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 69 deletions.
7 changes: 4 additions & 3 deletions ortools/pdlp/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ cc_library(
":trust_region",
"//ortools/base",
"//ortools/base:mathutil",
"//ortools/base:status_macros",
"//ortools/base:timer",
"//ortools/glop:parameters_cc_proto",
"//ortools/glop:preprocessor",
Expand Down Expand Up @@ -151,7 +150,6 @@ cc_test(
"//ortools/lp_data",
"//ortools/lp_data:base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@eigen//:eigen3",
Expand Down Expand Up @@ -182,6 +180,7 @@ cc_test(
":quadratic_program",
":test_util",
"//ortools/base:protobuf_util",
"//ortools/base:status_macros",
"//ortools/linear_solver:linear_solver_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand All @@ -200,13 +199,14 @@ cc_library(
"//ortools/base:status_macros",
"//ortools/linear_solver:linear_solver_cc_proto",
"//ortools/linear_solver:model_exporter",
"//ortools/lp_data:mps_reader",
"//ortools/lp_data:mps_reader_template",
"//ortools/util:file_util",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@eigen//:eigen3",
],
)
Expand Down Expand Up @@ -295,6 +295,7 @@ cc_test(
":sharder",
"//ortools/base",
"//ortools/base:mathutil",
"//ortools/base:threadpool",
"@com_google_absl//absl/random:distributions",
"@eigen//:eigen3",
],
Expand Down
1 change: 1 addition & 0 deletions ortools/pdlp/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pybind_extension(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@eigen//:eigen3",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
],
)

Expand Down
52 changes: 12 additions & 40 deletions ortools/pdlp/python/pdlp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "ortools/linear_solver/linear_solver.pb.h"
#include "ortools/pdlp/primal_dual_hybrid_gradient.h"
#include "ortools/pdlp/quadratic_program.h"
Expand All @@ -32,28 +31,14 @@
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "pybind11_protobuf/native_proto_caster.h"

namespace operations_research::pdlp {
namespace {

using ::pybind11::arg;

// TODO(user): The interface uses serialized protos because of issues building
// pybind11_protobuf. See
// https://github.com/protocolbuffers/protobuf/issues/9464. After
// pybind11_protobuf is working, this workaround can be removed.

// A mirror of pdlp::SolverResult except with a serialized SolveLog.
struct PywrapSolverResult {
Eigen::VectorXd primal_solution;
Eigen::VectorXd dual_solution;
Eigen::VectorXd reduced_costs;
pybind11::bytes solve_log_str;
};

} // namespace

PYBIND11_MODULE(pdlp, m) {
pybind11_protobuf::ImportNativeProtoCasters();
// ---------------------------------------------------------------------------
// quadratic_program.h
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -110,12 +95,8 @@ PYBIND11_MODULE(pdlp, m) {

m.def(
"qp_from_mpmodel_proto",
[](absl::string_view proto_str, bool relax_integer_variables,
[](const MPModelProto& proto, bool relax_integer_variables,
bool include_names) {
MPModelProto proto;
if (!proto.ParseFromString(std::string(proto_str))) {
throw std::invalid_argument("Unable to parse input proto");
}
absl::StatusOr<QuadraticProgram> qp =
QpFromMpModelProto(proto, relax_integer_variables, include_names);
if (qp.ok()) {
Expand All @@ -130,7 +111,7 @@ PYBIND11_MODULE(pdlp, m) {
m.def("qp_to_mpmodel_proto", [](const QuadraticProgram& qp) {
absl::StatusOr<MPModelProto> proto = QpToMpModelProto(qp);
if (proto.ok()) {
return pybind11::bytes(proto->SerializeAsString());
return *proto;
} else {
throw std::invalid_argument(absl::StrCat(proto.status().message()));
}
Expand All @@ -152,29 +133,20 @@ PYBIND11_MODULE(pdlp, m) {
.def_readwrite("primal_solution", &PrimalAndDualSolution::primal_solution)
.def_readwrite("dual_solution", &PrimalAndDualSolution::dual_solution);

pybind11::class_<PywrapSolverResult>(m, "SolverResult")
pybind11::class_<SolverResult>(m, "SolverResult")
.def(pybind11::init<>())
.def_readwrite("primal_solution", &PywrapSolverResult::primal_solution)
.def_readwrite("dual_solution", &PywrapSolverResult::dual_solution)
.def_readwrite("reduced_costs", &PywrapSolverResult::reduced_costs)
.def_readwrite("solve_log_str", &PywrapSolverResult::solve_log_str);
.def_readwrite("primal_solution", &SolverResult::primal_solution)
.def_readwrite("dual_solution", &SolverResult::dual_solution)
.def_readwrite("reduced_costs", &SolverResult::reduced_costs)
.def_readwrite("solve_log", &SolverResult::solve_log);

// TODO(user): Expose interrupt_solve and iteration_stats_callback.
m.def(
"primal_dual_hybrid_gradient",
[](QuadraticProgram qp, absl::string_view params_str,
[](QuadraticProgram qp, PrimalDualHybridGradientParams params,
std::optional<PrimalAndDualSolution> initial_solution) {
PrimalDualHybridGradientParams params;
if (!params.ParseFromString(std::string(params_str))) {
throw std::invalid_argument("Unable to parse input params");
}
SolverResult result = PrimalDualHybridGradient(
std::move(qp), params, std::move(initial_solution));
return PywrapSolverResult{
.primal_solution = std::move(result.primal_solution),
.dual_solution = std::move(result.dual_solution),
.reduced_costs = std::move(result.reduced_costs),
.solve_log_str = result.solve_log.SerializeAsString()};
return PrimalDualHybridGradient(std::move(qp), params,
std::move(initial_solution));
},
arg("qp"), arg("params"), arg("initial_solution") = std::nullopt);
}
Expand Down
39 changes: 17 additions & 22 deletions ortools/pdlp/python/pdlp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,14 @@ def test_validate_quadratic_program_dimensions_for_empty_qp(self):

def test_converts_from_tiny_mpmodel_lp(self):
lp_proto = small_proto_lp()
qp = pdlp.qp_from_mpmodel_proto(
lp_proto.SerializeToString(), relax_integer_variables=False
)
qp = pdlp.qp_from_mpmodel_proto(lp_proto, relax_integer_variables=False)
pdlp.validate_quadratic_program_dimensions(qp)
self.assertTrue(pdlp.is_linear_program(qp))
self.assertSameElements(qp.objective_vector, [0, -2])

def test_converts_from_tiny_mpmodel_qp(self):
qp_proto = small_proto_qp()
qp = pdlp.qp_from_mpmodel_proto(
qp_proto.SerializeToString(), relax_integer_variables=False
)
qp = pdlp.qp_from_mpmodel_proto(qp_proto, relax_integer_variables=False)
pdlp.validate_quadratic_program_dimensions(qp)
self.assertFalse(pdlp.is_linear_program(qp))
self.assertSameElements(qp.objective_vector, [0, 0])
Expand All @@ -110,7 +106,7 @@ def test_build_lp(self):
qp.variable_upper_bounds = [np.inf, np.inf]
qp.variable_names = ["x", "y"]
self.assertEqual(
linear_solver_pb2.MPModelProto.FromString(pdlp.qp_to_mpmodel_proto(qp)),
pdlp.qp_to_mpmodel_proto(qp),
small_proto_lp(),
)

Expand All @@ -125,7 +121,7 @@ def test_build_qp(self):
qp.variable_upper_bounds = [np.inf, np.inf]
qp.variable_names = ["x", "y"]
self.assertEqual(
linear_solver_pb2.MPModelProto.FromString(pdlp.qp_to_mpmodel_proto(qp)),
pdlp.qp_to_mpmodel_proto(qp),
small_proto_qp(),
)

Expand Down Expand Up @@ -197,11 +193,10 @@ def test_iteration_limit(self):
params = solvers_pb2.PrimalDualHybridGradientParams()
params.termination_criteria.iteration_limit = 1
params.termination_check_frequency = 1
result = pdlp.primal_dual_hybrid_gradient(tiny_lp(), params.SerializeToString())
solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str)
self.assertLessEqual(solve_log.iteration_count, 1)
result = pdlp.primal_dual_hybrid_gradient(tiny_lp(), params)
self.assertLessEqual(result.solve_log.iteration_count, 1)
self.assertEqual(
solve_log.termination_reason,
result.solve_log.termination_reason,
solve_log_pb2.TERMINATION_REASON_ITERATION_LIMIT,
)

Expand All @@ -210,10 +205,10 @@ def test_solution(self):
opt_criteria = params.termination_criteria.simple_optimality_criteria
opt_criteria.eps_optimal_relative = 0.0
opt_criteria.eps_optimal_absolute = 1.0e-10
result = pdlp.primal_dual_hybrid_gradient(tiny_lp(), params.SerializeToString())
solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str)
result = pdlp.primal_dual_hybrid_gradient(tiny_lp(), params)
self.assertEqual(
solve_log.termination_reason, solve_log_pb2.TERMINATION_REASON_OPTIMAL
result.solve_log.termination_reason,
solve_log_pb2.TERMINATION_REASON_OPTIMAL,
)
self.assertSequenceAlmostEqual(result.primal_solution, [1.0, 0.0, 6.0, 2.0])
self.assertSequenceAlmostEqual(result.dual_solution, [0.5, 4.0, 0.0])
Expand All @@ -224,10 +219,10 @@ def test_solution_2(self):
opt_criteria = params.termination_criteria.simple_optimality_criteria
opt_criteria.eps_optimal_relative = 0.0
opt_criteria.eps_optimal_absolute = 1.0e-10
result = pdlp.primal_dual_hybrid_gradient(test_lp(), params.SerializeToString())
solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str)
result = pdlp.primal_dual_hybrid_gradient(test_lp(), params)
self.assertEqual(
solve_log.termination_reason, solve_log_pb2.TERMINATION_REASON_OPTIMAL
result.solve_log.termination_reason,
solve_log_pb2.TERMINATION_REASON_OPTIMAL,
)
self.assertSequenceAlmostEqual(result.primal_solution, [-1, 8, 1, 2.5])
self.assertSequenceAlmostEqual(result.dual_solution, [-2, 0, 2.375, 2 / 3])
Expand All @@ -244,13 +239,13 @@ def test_starting_point(self):
start.primal_solution = [1.0, 0.0, 6.0, 2.0]
start.dual_solution = [0.5, 4.0, 0.0]
result = pdlp.primal_dual_hybrid_gradient(
tiny_lp(), params.SerializeToString(), initial_solution=start
tiny_lp(), params, initial_solution=start
)
solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str)
self.assertEqual(
solve_log.termination_reason, solve_log_pb2.TERMINATION_REASON_OPTIMAL
result.solve_log.termination_reason,
solve_log_pb2.TERMINATION_REASON_OPTIMAL,
)
self.assertEqual(solve_log.iteration_count, 0)
self.assertEqual(result.solve_log.iteration_count, 0)


if __name__ == "__main__":
Expand Down
7 changes: 3 additions & 4 deletions ortools/pdlp/samples/simple_pdlp_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,9 @@ def main() -> None:
params.verbosity_level = 0
params.presolve_options.use_glop = False

# Call the main solve function. Note that a quirk of the pywrap11 API forces
# us to serialize the `params` and deserialize the `solve_log` proto messages.
result = pdlp.primal_dual_hybrid_gradient(simple_lp(), params.SerializeToString())
solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str)
# Call the main solve function.
result = pdlp.primal_dual_hybrid_gradient(simple_lp(), params)
solve_log = result.solve_log

if solve_log.termination_reason == solve_log_pb2.TERMINATION_REASON_OPTIMAL:
print("Solve successful")
Expand Down

0 comments on commit bad3f17

Please sign in to comment.