Skip to content

Commit

Permalink
[CP-SAT] Fix #4373
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Sep 27, 2024
1 parent 3c7bc49 commit c4fac77
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 77 deletions.
26 changes: 26 additions & 0 deletions ortools/algorithms/sparse_permutation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,30 @@ std::string SparsePermutation::DebugString() const {
return out;
}

int SparsePermutation::Image(int element) const {
for (int c = 0; c < NumCycles(); ++c) {
int cur_element = LastElementInCycle(c);
for (int image : Cycle(c)) {
if (cur_element == element) {
return image;
}
cur_element = image;
}
}
return element;
}

int SparsePermutation::InverseImage(int element) const {
for (int c = 0; c < NumCycles(); ++c) {
int cur_element = LastElementInCycle(c);
for (int image : Cycle(c)) {
if (image == element) {
return cur_element;
}
cur_element = image;
}
}
return element;
}

} // namespace operations_research
26 changes: 26 additions & 0 deletions ortools/algorithms/sparse_permutation.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class SparsePermutation {
// information with the loop above. Not sure it is needed though.
int LastElementInCycle(int i) const;

// Returns the image of the given element or `element` itself if it is stable
// under the permutation.
int Image(int element) const;
int InverseImage(int element) const;

// To add a cycle to the permutation, repeatedly call AddToCurrentCycle()
// with the cycle's orbit, then call CloseCurrentCycle();
// This shouldn't be called on trivial cycles (of length 1).
Expand All @@ -76,6 +81,9 @@ class SparsePermutation {
// Example: "(1 4 3) (5 9) (6 8 7)".
std::string DebugString() const;

template <typename Collection>
void ApplyToDenseCollection(Collection& span) const;

private:
const int size_;
std::vector<int> cycles_;
Expand Down Expand Up @@ -129,6 +137,24 @@ inline int SparsePermutation::LastElementInCycle(int i) const {
return cycles_[cycle_ends_[i] - 1];
}

template <typename Collection>
void SparsePermutation::ApplyToDenseCollection(Collection& span) const {
using T = typename Collection::value_type;
for (int c = 0; c < NumCycles(); ++c) {
const int last_element_idx = LastElementInCycle(c);
int element = last_element_idx;
T last_element = span[element];
for (int image : Cycle(c)) {
if (image == last_element_idx) {
span[element] = last_element;
} else {
span[element] = span[image];
}
element = image;
}
}
}

} // namespace operations_research

#endif // OR_TOOLS_ALGORITHMS_SPARSE_PERMUTATION_H_
15 changes: 15 additions & 0 deletions ortools/algorithms/sparse_permutation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <memory>
#include <random>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
Expand Down Expand Up @@ -73,6 +74,20 @@ TEST(SparsePermutationTest, Identity) {
EXPECT_EQ(0, permutation.NumCycles());
}

TEST(SparsePermutationTest, ApplyToVector) {
std::vector<std::string> v = {"0", "1", "2", "3", "4", "5", "6", "7", "8"};
SparsePermutation permutation(v.size());
permutation.AddToCurrentCycle(4);
permutation.AddToCurrentCycle(2);
permutation.AddToCurrentCycle(7);
permutation.CloseCurrentCycle();
permutation.AddToCurrentCycle(6);
permutation.AddToCurrentCycle(1);
permutation.CloseCurrentCycle();
permutation.ApplyToDenseCollection(v);
EXPECT_THAT(v, ElementsAre("0", "6", "7", "3", "2", "5", "1", "4", "8"));
}

// Generate a bunch of permutation on a 'huge' space, but that have very few
// displacements. This would OOM if the implementation was O(N); we verify
// that it doesn't.
Expand Down
18 changes: 0 additions & 18 deletions ortools/graph/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -355,24 +355,6 @@ cc_library(
],
)

# need C++20
#cc_test(
# name = "k_shortest_paths_test",
# srcs = ["k_shortest_paths_test.cc"],
# deps = [
# ":graph",
# ":io",
# ":k_shortest_paths",
# ":shortest_paths",
# "//ortools/base:gmock_main",
# "@com_google_absl//absl/algorithm:container",
# "@com_google_absl//absl/log:check",
# "@com_google_absl//absl/random:distributions",
# "@com_google_absl//absl/strings",
# "@com_google_benchmark//:benchmark",
# ],
#)

# Flow problem protobuf representation
proto_library(
name = "flow_problem_proto",
Expand Down
4 changes: 3 additions & 1 deletion ortools/sat/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,6 @@ cc_library(
hdrs = ["presolve_context.h"],
deps = [
":cp_model_cc_proto",
":cp_model_checker",
":cp_model_loader",
":cp_model_mapping",
":cp_model_utils",
Expand All @@ -668,6 +667,7 @@ cc_library(
":sat_parameters_cc_proto",
":sat_solver",
":util",
"//ortools/algorithms:sparse_permutation",
"//ortools/base",
"//ortools/base:mathutil",
"//ortools/port:proto_utils",
Expand Down Expand Up @@ -1163,6 +1163,7 @@ cc_library(
"//ortools/algorithms:dynamic_partition",
"//ortools/algorithms:sparse_permutation",
"//ortools/base",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/types:span",
],
Expand All @@ -1176,6 +1177,7 @@ cc_test(
":symmetry_util",
"//ortools/algorithms:sparse_permutation",
"//ortools/base:gmock_main",
"@com_google_absl//absl/types:span",
],
)

Expand Down
63 changes: 61 additions & 2 deletions ortools/sat/cp_model_symmetries.cc
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,47 @@ std::vector<int64_t> BuildInequalityCoeffsForOrbitope(
return out;
}

void UpdateHintAfterFixingBoolToBreakSymmetry(
PresolveContext* context, int var, bool fixed_value,
const std::vector<std::unique_ptr<SparsePermutation>>& generators) {
if (!context->VarHasSolutionHint(var)) {
return;
}
const int64_t hinted_value = context->SolutionHint(var);
if (hinted_value == static_cast<int64_t>(fixed_value)) {
return;
}

std::vector<int> schrier_vector;
std::vector<int> orbit;
GetSchreierVectorAndOrbit(var, generators, &schrier_vector, &orbit);

bool found_target = false;
int target_var;
for (int v : orbit) {
if (context->VarHasSolutionHint(v) &&
context->SolutionHint(v) == static_cast<int64_t>(fixed_value)) {
found_target = true;
target_var = v;
break;
}
}
if (!found_target) {
context->UpdateRuleStats(
"hint: couldn't transform infeasible hint properly");
return;
}

const std::vector<int> generator_idx =
TracePoint(target_var, schrier_vector, generators);
for (const int i : generator_idx) {
context->PermuteHintValues(*generators[i]);
}

DCHECK(context->VarHasSolutionHint(var));
DCHECK_EQ(context->SolutionHint(var), fixed_value);
}

} // namespace

bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) {
Expand Down Expand Up @@ -1010,6 +1051,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) {
// fixing do not exploit the full structure of these symmeteries. Note
// however that the fixing via propagation above close cod105 even more
// efficiently.
std::vector<int> var_can_be_true_per_orbit(num_vars, -1);
{
std::vector<int> tmp_to_clear;
std::vector<int> tmp_sizes(num_vars, 0);
Expand Down Expand Up @@ -1050,7 +1092,11 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) {
}

// We push all but the first one in each orbit.
if (tmp_sizes[rep] == 0) can_be_fixed_to_false.push_back(var);
if (tmp_sizes[rep] == 0) {
can_be_fixed_to_false.push_back(var);
} else {
var_can_be_true_per_orbit[rep] = var;
}
tmp_sizes[rep] = 0;
}
} else {
Expand Down Expand Up @@ -1131,7 +1177,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) {
}
}

// Supper simple heuristic to use the orbitope or not.
// Super simple heuristic to use the orbitope or not.
//
// In an orbitope with an at most one on each row, we can fix the upper right
// triangle. We could use a formula, but the loop is fast enough.
Expand All @@ -1153,6 +1199,19 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) {
const int var = can_be_fixed_to_false[i];
if (orbits[var] == orbit_index) ++num_in_orbit;
context->UpdateRuleStats("symmetry: fixed to false in general orbit");
if (context->VarHasSolutionHint(var) && context->SolutionHint(var) == 1 &&
var_can_be_true_per_orbit[orbits[var]] != -1) {
// We are breaking the symmetry in a way that makes the hint invalid.
// We want `var` to be false, so we would naively pick a symmetry to
// enforce that. But that will be wrong if we do this twice: after we
// permute the hint to fix the first one we would look for a symmetry
// group element that fixes the second one to false. But there are many
// of those, and picking the wrong one would risk making the first one
// true again. Since this is a AMO, fixing the one that is true doesn't
// have this problem.
UpdateHintAfterFixingBoolToBreakSymmetry(
context, var_can_be_true_per_orbit[orbits[var]], true, generators);
}
if (!context->SetLiteralToFalse(var)) return false;
}

Expand Down
8 changes: 8 additions & 0 deletions ortools/sat/presolve_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "absl/numeric/int128.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "ortools/algorithms/sparse_permutation.h"
#include "ortools/base/logging.h"
#include "ortools/base/mathutil.h"
#include "ortools/port/proto_utils.h"
Expand Down Expand Up @@ -725,6 +726,7 @@ void PresolveContext::UpdateConstraintVariableUsage(int c) {
}

bool PresolveContext::ConstraintVariableGraphIsUpToDate() const {
if (is_unsat_) return true; // We do not care in this case.
return constraint_to_vars_.size() == working_model->constraints_size();
}

Expand Down Expand Up @@ -1016,6 +1018,12 @@ bool PresolveContext::CanonicalizeAffineVariable(int ref, int64_t coeff,
return true;
}

void PresolveContext::PermuteHintValues(const SparsePermutation& perm) {
CHECK(hint_is_loaded_);
perm.ApplyToDenseCollection(hint_);
perm.ApplyToDenseCollection(hint_has_value_);
}

bool PresolveContext::StoreAffineRelation(int ref_x, int ref_y, int64_t coeff,
int64_t offset,
bool debug_no_recursion) {
Expand Down
3 changes: 3 additions & 0 deletions ortools/sat/presolve_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/algorithms/sparse_permutation.h"
#include "ortools/base/logging.h"
#include "ortools/sat/cp_model.pb.h"
#include "ortools/sat/cp_model_utils.h"
Expand Down Expand Up @@ -574,6 +575,8 @@ class PresolveContext {
// the hint, in order to maintain it as best as possible during presolve.
void LoadSolutionHint();

void PermuteHintValues(const SparsePermutation& perm);

// Solution hint accessor.
bool VarHasSolutionHint(int var) const { return hint_has_value_[var]; }
int64_t SolutionHint(int var) const { return hint_[var]; }
Expand Down
38 changes: 38 additions & 0 deletions ortools/sat/symmetry_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <memory>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "ortools/algorithms/dynamic_partition.h"
Expand Down Expand Up @@ -194,5 +195,42 @@ std::vector<int> GetOrbitopeOrbits(
return orbits;
}

void GetSchreierVectorAndOrbit(
int point, absl::Span<const std::unique_ptr<SparsePermutation>> generators,
std::vector<int>* schrier_vector, std::vector<int>* orbit) {
schrier_vector->clear();
*orbit = {point};
if (generators.empty()) return;
schrier_vector->resize(generators[0]->Size(), -1);
absl::flat_hash_set<int> orbit_set = {point};
for (int i = 0; i < orbit->size(); ++i) {
const int orbit_element = (*orbit)[i];
for (int i = 0; i < generators.size(); ++i) {
DCHECK_EQ(schrier_vector->size(), generators[i]->Size());
const int image = generators[i]->Image(orbit_element);
if (image == orbit_element) continue;
const auto [it, inserted] = orbit_set.insert(image);
if (inserted) {
(*schrier_vector)[image] = i;
orbit->push_back(image);
}
}
}
}

std::vector<int> TracePoint(
int point, absl::Span<const int> schrier_vector,
absl::Span<const std::unique_ptr<SparsePermutation>> generators) {
std::vector<int> result;
while (schrier_vector[point] != -1) {
const SparsePermutation& perm = *generators[schrier_vector[point]];
result.push_back(schrier_vector[point]);
const int next = perm.InverseImage(point);
DCHECK_NE(next, point);
point = next;
}
return result;
}

} // namespace sat
} // namespace operations_research
13 changes: 13 additions & 0 deletions ortools/sat/symmetry_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ std::vector<int> GetOrbits(
std::vector<int> GetOrbitopeOrbits(int n,
absl::Span<const std::vector<int>> orbitope);

// See Chapter 7 of Butler, Gregory, ed. Fundamental algorithms for permutation
// groups. Berlin, Heidelberg: Springer Berlin Heidelberg, 1991.
void GetSchreierVectorAndOrbit(
int point, absl::Span<const std::unique_ptr<SparsePermutation>> generators,
std::vector<int>* schrier_vector, std::vector<int>* orbit);

// Given a schreier vector for a given base point and a point in the same orbit
// of the base point, returns a list of index of the `generators` to apply to
// get a permutation mapping the base point to get the given point.
std::vector<int> TracePoint(
int point, absl::Span<const int> schrier_vector,
absl::Span<const std::unique_ptr<SparsePermutation>> generators);

// Given the generators for a permutation group of [0, n-1], update it to
// a set of generators of the group stabilizing the given element.
//
Expand Down
Loading

0 comments on commit c4fac77

Please sign in to comment.