diff --git a/ortools/pdlp/BUILD.bazel b/ortools/pdlp/BUILD.bazel index be2e1b96ade..b395568d00b 100644 --- a/ortools/pdlp/BUILD.bazel +++ b/ortools/pdlp/BUILD.bazel @@ -110,7 +110,6 @@ cc_library( ":trust_region", "//ortools/base", "//ortools/base:mathutil", - "//ortools/base:status_macros", "//ortools/base:timer", "//ortools/glop:parameters_cc_proto", "//ortools/glop:preprocessor", @@ -151,7 +150,6 @@ cc_test( "//ortools/lp_data", "//ortools/lp_data:base", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@eigen//:eigen3", @@ -182,6 +180,7 @@ cc_test( ":quadratic_program", ":test_util", "//ortools/base:protobuf_util", + "//ortools/base:status_macros", "//ortools/linear_solver:linear_solver_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -200,13 +199,14 @@ cc_library( "//ortools/base:status_macros", "//ortools/linear_solver:linear_solver_cc_proto", "//ortools/linear_solver:model_exporter", - "//ortools/lp_data:mps_reader", + "//ortools/lp_data:mps_reader_template", "//ortools/util:file_util", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@eigen//:eigen3", ], ) @@ -295,6 +295,7 @@ cc_test( ":sharder", "//ortools/base", "//ortools/base:mathutil", + "//ortools/base:threadpool", "@com_google_absl//absl/random:distributions", "@eigen//:eigen3", ], diff --git a/ortools/pdlp/python/BUILD.bazel b/ortools/pdlp/python/BUILD.bazel index 7e2b450b48d..3188de46feb 100644 --- a/ortools/pdlp/python/BUILD.bazel +++ b/ortools/pdlp/python/BUILD.bazel @@ -31,6 +31,7 @@ pybind_extension( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@eigen//:eigen3", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], ) diff --git a/ortools/pdlp/python/pdlp.cc b/ortools/pdlp/python/pdlp.cc index bffb7979ac7..ebd0a0db45b 100644 --- a/ortools/pdlp/python/pdlp.cc +++ b/ortools/pdlp/python/pdlp.cc @@ -22,7 +22,6 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "ortools/linear_solver/linear_solver.pb.h" #include "ortools/pdlp/primal_dual_hybrid_gradient.h" #include "ortools/pdlp/quadratic_program.h" @@ -32,28 +31,14 @@ #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "pybind11/stl.h" +#include "pybind11_protobuf/native_proto_caster.h" namespace operations_research::pdlp { -namespace { using ::pybind11::arg; -// TODO(user): The interface uses serialized protos because of issues building -// pybind11_protobuf. See -// https://github.com/protocolbuffers/protobuf/issues/9464. After -// pybind11_protobuf is working, this workaround can be removed. - -// A mirror of pdlp::SolverResult except with a serialized SolveLog. -struct PywrapSolverResult { - Eigen::VectorXd primal_solution; - Eigen::VectorXd dual_solution; - Eigen::VectorXd reduced_costs; - pybind11::bytes solve_log_str; -}; - -} // namespace - PYBIND11_MODULE(pdlp, m) { + pybind11_protobuf::ImportNativeProtoCasters(); // --------------------------------------------------------------------------- // quadratic_program.h // --------------------------------------------------------------------------- @@ -110,12 +95,8 @@ PYBIND11_MODULE(pdlp, m) { m.def( "qp_from_mpmodel_proto", - [](absl::string_view proto_str, bool relax_integer_variables, + [](const MPModelProto& proto, bool relax_integer_variables, bool include_names) { - MPModelProto proto; - if (!proto.ParseFromString(std::string(proto_str))) { - throw std::invalid_argument("Unable to parse input proto"); - } absl::StatusOr qp = QpFromMpModelProto(proto, relax_integer_variables, include_names); if (qp.ok()) { @@ -130,7 +111,7 @@ PYBIND11_MODULE(pdlp, m) { m.def("qp_to_mpmodel_proto", [](const QuadraticProgram& qp) { absl::StatusOr proto = QpToMpModelProto(qp); if (proto.ok()) { - return pybind11::bytes(proto->SerializeAsString()); + return *proto; } else { throw std::invalid_argument(absl::StrCat(proto.status().message())); } @@ -152,29 +133,20 @@ PYBIND11_MODULE(pdlp, m) { .def_readwrite("primal_solution", &PrimalAndDualSolution::primal_solution) .def_readwrite("dual_solution", &PrimalAndDualSolution::dual_solution); - pybind11::class_(m, "SolverResult") + pybind11::class_(m, "SolverResult") .def(pybind11::init<>()) - .def_readwrite("primal_solution", &PywrapSolverResult::primal_solution) - .def_readwrite("dual_solution", &PywrapSolverResult::dual_solution) - .def_readwrite("reduced_costs", &PywrapSolverResult::reduced_costs) - .def_readwrite("solve_log_str", &PywrapSolverResult::solve_log_str); + .def_readwrite("primal_solution", &SolverResult::primal_solution) + .def_readwrite("dual_solution", &SolverResult::dual_solution) + .def_readwrite("reduced_costs", &SolverResult::reduced_costs) + .def_readwrite("solve_log", &SolverResult::solve_log); // TODO(user): Expose interrupt_solve and iteration_stats_callback. m.def( "primal_dual_hybrid_gradient", - [](QuadraticProgram qp, absl::string_view params_str, + [](QuadraticProgram qp, PrimalDualHybridGradientParams params, std::optional initial_solution) { - PrimalDualHybridGradientParams params; - if (!params.ParseFromString(std::string(params_str))) { - throw std::invalid_argument("Unable to parse input params"); - } - SolverResult result = PrimalDualHybridGradient( - std::move(qp), params, std::move(initial_solution)); - return PywrapSolverResult{ - .primal_solution = std::move(result.primal_solution), - .dual_solution = std::move(result.dual_solution), - .reduced_costs = std::move(result.reduced_costs), - .solve_log_str = result.solve_log.SerializeAsString()}; + return PrimalDualHybridGradient(std::move(qp), params, + std::move(initial_solution)); }, arg("qp"), arg("params"), arg("initial_solution") = std::nullopt); } diff --git a/ortools/pdlp/python/pdlp_test.py b/ortools/pdlp/python/pdlp_test.py index c19a2709627..05561ed36d0 100644 --- a/ortools/pdlp/python/pdlp_test.py +++ b/ortools/pdlp/python/pdlp_test.py @@ -84,18 +84,14 @@ def test_validate_quadratic_program_dimensions_for_empty_qp(self): def test_converts_from_tiny_mpmodel_lp(self): lp_proto = small_proto_lp() - qp = pdlp.qp_from_mpmodel_proto( - lp_proto.SerializeToString(), relax_integer_variables=False - ) + qp = pdlp.qp_from_mpmodel_proto(lp_proto, relax_integer_variables=False) pdlp.validate_quadratic_program_dimensions(qp) self.assertTrue(pdlp.is_linear_program(qp)) self.assertSameElements(qp.objective_vector, [0, -2]) def test_converts_from_tiny_mpmodel_qp(self): qp_proto = small_proto_qp() - qp = pdlp.qp_from_mpmodel_proto( - qp_proto.SerializeToString(), relax_integer_variables=False - ) + qp = pdlp.qp_from_mpmodel_proto(qp_proto, relax_integer_variables=False) pdlp.validate_quadratic_program_dimensions(qp) self.assertFalse(pdlp.is_linear_program(qp)) self.assertSameElements(qp.objective_vector, [0, 0]) @@ -110,7 +106,7 @@ def test_build_lp(self): qp.variable_upper_bounds = [np.inf, np.inf] qp.variable_names = ["x", "y"] self.assertEqual( - linear_solver_pb2.MPModelProto.FromString(pdlp.qp_to_mpmodel_proto(qp)), + pdlp.qp_to_mpmodel_proto(qp), small_proto_lp(), ) @@ -125,7 +121,7 @@ def test_build_qp(self): qp.variable_upper_bounds = [np.inf, np.inf] qp.variable_names = ["x", "y"] self.assertEqual( - linear_solver_pb2.MPModelProto.FromString(pdlp.qp_to_mpmodel_proto(qp)), + pdlp.qp_to_mpmodel_proto(qp), small_proto_qp(), ) @@ -197,11 +193,10 @@ def test_iteration_limit(self): params = solvers_pb2.PrimalDualHybridGradientParams() params.termination_criteria.iteration_limit = 1 params.termination_check_frequency = 1 - result = pdlp.primal_dual_hybrid_gradient(tiny_lp(), params.SerializeToString()) - solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str) - self.assertLessEqual(solve_log.iteration_count, 1) + result = pdlp.primal_dual_hybrid_gradient(tiny_lp(), params) + self.assertLessEqual(result.solve_log.iteration_count, 1) self.assertEqual( - solve_log.termination_reason, + result.solve_log.termination_reason, solve_log_pb2.TERMINATION_REASON_ITERATION_LIMIT, ) @@ -210,10 +205,10 @@ def test_solution(self): opt_criteria = params.termination_criteria.simple_optimality_criteria opt_criteria.eps_optimal_relative = 0.0 opt_criteria.eps_optimal_absolute = 1.0e-10 - result = pdlp.primal_dual_hybrid_gradient(tiny_lp(), params.SerializeToString()) - solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str) + result = pdlp.primal_dual_hybrid_gradient(tiny_lp(), params) self.assertEqual( - solve_log.termination_reason, solve_log_pb2.TERMINATION_REASON_OPTIMAL + result.solve_log.termination_reason, + solve_log_pb2.TERMINATION_REASON_OPTIMAL, ) self.assertSequenceAlmostEqual(result.primal_solution, [1.0, 0.0, 6.0, 2.0]) self.assertSequenceAlmostEqual(result.dual_solution, [0.5, 4.0, 0.0]) @@ -224,10 +219,10 @@ def test_solution_2(self): opt_criteria = params.termination_criteria.simple_optimality_criteria opt_criteria.eps_optimal_relative = 0.0 opt_criteria.eps_optimal_absolute = 1.0e-10 - result = pdlp.primal_dual_hybrid_gradient(test_lp(), params.SerializeToString()) - solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str) + result = pdlp.primal_dual_hybrid_gradient(test_lp(), params) self.assertEqual( - solve_log.termination_reason, solve_log_pb2.TERMINATION_REASON_OPTIMAL + result.solve_log.termination_reason, + solve_log_pb2.TERMINATION_REASON_OPTIMAL, ) self.assertSequenceAlmostEqual(result.primal_solution, [-1, 8, 1, 2.5]) self.assertSequenceAlmostEqual(result.dual_solution, [-2, 0, 2.375, 2 / 3]) @@ -244,13 +239,13 @@ def test_starting_point(self): start.primal_solution = [1.0, 0.0, 6.0, 2.0] start.dual_solution = [0.5, 4.0, 0.0] result = pdlp.primal_dual_hybrid_gradient( - tiny_lp(), params.SerializeToString(), initial_solution=start + tiny_lp(), params, initial_solution=start ) - solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str) self.assertEqual( - solve_log.termination_reason, solve_log_pb2.TERMINATION_REASON_OPTIMAL + result.solve_log.termination_reason, + solve_log_pb2.TERMINATION_REASON_OPTIMAL, ) - self.assertEqual(solve_log.iteration_count, 0) + self.assertEqual(result.solve_log.iteration_count, 0) if __name__ == "__main__": diff --git a/ortools/pdlp/samples/simple_pdlp_program.py b/ortools/pdlp/samples/simple_pdlp_program.py index 2afabfb3a0d..a6250a0fe95 100644 --- a/ortools/pdlp/samples/simple_pdlp_program.py +++ b/ortools/pdlp/samples/simple_pdlp_program.py @@ -71,10 +71,9 @@ def main() -> None: params.verbosity_level = 0 params.presolve_options.use_glop = False - # Call the main solve function. Note that a quirk of the pywrap11 API forces - # us to serialize the `params` and deserialize the `solve_log` proto messages. - result = pdlp.primal_dual_hybrid_gradient(simple_lp(), params.SerializeToString()) - solve_log = solve_log_pb2.SolveLog.FromString(result.solve_log_str) + # Call the main solve function. + result = pdlp.primal_dual_hybrid_gradient(simple_lp(), params) + solve_log = result.solve_log if solve_log.termination_reason == solve_log_pb2.TERMINATION_REASON_OPTIMAL: print("Solve successful")