From 5986fe4ead0fc864c19f6b864d49e1b0387e1905 Mon Sep 17 00:00:00 2001 From: Alexandre Dubray Date: Mon, 7 Aug 2023 11:00:54 +0200 Subject: [PATCH] Better handling of propagator type + propagation cutoff in approximate search --- src/compiler/exact.rs | 6 ++-- src/core/components.rs | 7 +++-- src/main.rs | 8 ++--- src/parser.rs | 4 +-- src/propagator.rs | 44 +++++++++++++++++++-------- src/search/approximate.rs | 62 +++++++++++++++++--------------------- src/search/sequential.rs | 6 ++-- tests/integration_tests.rs | 18 +++++------ 8 files changed, 83 insertions(+), 72 deletions(-) diff --git a/src/compiler/exact.rs b/src/compiler/exact.rs index 365a927..8a39b90 100644 --- a/src/compiler/exact.rs +++ b/src/compiler/exact.rs @@ -31,7 +31,7 @@ use search_trail::{StateManager, SaveAndRestore}; use crate::core::components::{ComponentExtractor, ComponentIndex}; use crate::core::graph::*; use crate::heuristics::branching::BranchingDecision; -use crate::propagator::FTReachablePropagator; +use crate::propagator::CompiledPropagator; use crate::common::*; use crate::compiler::circuit::*; @@ -50,7 +50,7 @@ where /// Heuristics that decide on which distribution to branch next branching_heuristic: &'b mut B, /// The propagator - propagator: FTReachablePropagator, + propagator: CompiledPropagator, /// Cache used to store results of sub-problems cache: FxHashMap>, } @@ -64,7 +64,7 @@ where state: StateManager, component_extractor: ComponentExtractor, branching_heuristic: &'b mut B, - propagator: FTReachablePropagator, + propagator: CompiledPropagator, ) -> Self { let cache = FxHashMap::default(); Self { diff --git a/src/core/components.rs b/src/core/components.rs index b367940..b87ef50 100644 --- a/src/core/components.rs +++ b/src/core/components.rs @@ -210,7 +210,7 @@ impl ComponentExtractor { /// This function is responsible of updating the data structure with the new connected /// components in `g` given its current assignments. /// Returns true iff at least one component has been detected and it contains one distribution - pub fn detect_components( + pub fn detect_components( &mut self, g: &mut Graph, state: &mut StateManager, @@ -344,6 +344,7 @@ mod test_component_detection { use crate::core::graph::{Graph, VariableIndex, ClauseIndex}; use crate::core::components::*; use search_trail::{StateManager, SaveAndRestore}; + use crate::propagator::SearchPropagator; // Graph used for the tests: // @@ -397,7 +398,7 @@ mod test_component_detection { let mut state = StateManager::default(); let mut g = get_graph(&mut state); let mut extractor = ComponentExtractor::new(&g, &mut state); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = SearchPropagator::new(); g.set_clause_unconstrained(ClauseIndex(4), &mut state); extractor.detect_components(&mut g, &mut state, ComponentIndex(0), &mut propagator); @@ -418,7 +419,7 @@ mod test_component_detection { let mut state = StateManager::default(); let mut g = get_graph(&mut state); let mut extractor = ComponentExtractor::new(&g, &mut state); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = SearchPropagator::new(); state.save_state(); diff --git a/src/main.rs b/src/main.rs index 67a1f6b..1910a30 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,7 +33,7 @@ use crate::core::components::ComponentExtractor; use parser::*; use heuristics::branching::*; use search::{ExactDefaultSolver, ExactQuietSolver, ApproximateDefaultSolver, ApproximateQuietSolver}; -use propagator::FTReachablePropagator; +use propagator::{SearchPropagator, CompiledPropagator, MixedPropagator}; use compiler::exact::ExactDACCompiler; use compiler::circuit::Dac; @@ -132,7 +132,7 @@ fn read_compiled(input: PathBuf, dotfile: Option) { fn run_compilation(input: PathBuf, branching: Branching, fdac: Option, dotfile: Option) { let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = CompiledPropagator::new(); let graph = graph_from_ppidimacs(&input, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); let mut branching_heuristic: Box = match branching { @@ -170,7 +170,7 @@ fn run_compilation(input: PathBuf, branching: Branching, fdac: Option, fn run_approx_search(input: PathBuf, branching: Branching, statistics: bool, memory: Option, epsilon: f64) { let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = MixedPropagator::new(); let graph = graph_from_ppidimacs(&input, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); let mut branching_heuristic: Box = match branching { @@ -212,7 +212,7 @@ fn run_approx_search(input: PathBuf, branching: Branching, statistics: bool, mem fn run_search(input: PathBuf, branching: Branching, statistics: bool, memory: Option) { let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = SearchPropagator::new(); let graph = graph_from_ppidimacs(&input, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); let mut branching_heuristic: Box = match branching { diff --git a/src/parser.rs b/src/parser.rs index f67c39f..f0de362 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -49,7 +49,7 @@ use std::fs::File; use std::io::{BufRead, BufReader}; use std::path::PathBuf; -pub fn graph_from_ppidimacs( +pub fn graph_from_ppidimacs( filepath: &PathBuf, state: &mut StateManager, propagator: &mut FTReachablePropagator, @@ -77,7 +77,7 @@ pub fn graph_from_ppidimacs( .collect::>(); let nodes = g.add_distribution(&split, state); for i in 0..split.len() { - if !C && split[i] == 1.0 { + if propagator.is_search() && split[i] == 1.0 { propagator.add_to_propagation_stack(nodes[i], true); } } diff --git a/src/propagator.rs b/src/propagator.rs index 71f52e7..b5060d3 100644 --- a/src/propagator.rs +++ b/src/propagator.rs @@ -57,7 +57,15 @@ pub struct Unsat; pub type PropagationResult = Result<(), Unsat>; -pub struct FTReachablePropagator { +const COMPILED_PROPAGATION: u8 = 1; +const SEARCH_PROPAGATION: u8 = 2; +const MIXED_PROPAGATION: u8 = 3; + +pub type CompiledPropagator = FTReachablePropagator; +pub type SearchPropagator = FTReachablePropagator; +pub type MixedPropagator = FTReachablePropagator; + +pub struct FTReachablePropagator { propagation_stack: Vec<(VariableIndex, bool)>, pub unconstrained_clauses: Vec, t_reachable: Vec, @@ -67,13 +75,13 @@ pub struct FTReachablePropagator { propagation_prob: Float, } -impl Default for FTReachablePropagator { +impl Default for FTReachablePropagator { fn default() -> Self { Self::new() } } -impl FTReachablePropagator { +impl FTReachablePropagator { pub fn new() -> Self { Self { @@ -87,6 +95,10 @@ impl FTReachablePropagator { } } + pub fn is_search(&self) -> bool { + C == SEARCH_PROPAGATION || C == MIXED_PROPAGATION + } + /// Sets the number of clauses for the f-reachable and t-reachable vectors pub fn set_number_clauses(&mut self, n: usize) { self.t_reachable.resize(n, false); @@ -112,7 +124,7 @@ impl FTReachablePropagator { self.unconstrained_clauses.push(clause); } } - + /// Returns the propagation probability of the last call to propagate pub fn get_propagation_prob(&self) -> &Float { &self.propagation_prob @@ -141,14 +153,17 @@ impl FTReachablePropagator { /// Computes the unconstrained probability of a distribution. When a distribution does not appear anymore in any constrained /// clauses, the probability of branching on it can be pre-computed. This is what this function returns. fn propagate_unconstrained_distribution(&mut self, g: &Graph, distribution: DistributionIndex, state: &StateManager) { - if C { - self.unconstrained_distributions.push(distribution); - } else if g.distribution_number_false(distribution, state) != 0 { - let mut p = f128!(0.0); - for w in g.distribution_variable_iter(distribution).filter(|v| !g.is_variable_fixed(*v, state)).map(|v| g.get_variable_weight(v).unwrap()) { - p += w; + if g.distribution_number_false(distribution, state) != 0 { + if C == COMPILED_PROPAGATION { + self.unconstrained_distributions.push(distribution); + } + if C == SEARCH_PROPAGATION || C == MIXED_PROPAGATION { + let mut p = f128!(0.0); + for w in g.distribution_variable_iter(distribution).filter(|v| !g.is_variable_fixed(*v, state)).map(|v| g.get_variable_weight(v).unwrap()) { + p += w; + } + self.propagation_prob *= &p; } - self.propagation_prob *= &p; } } @@ -242,11 +257,14 @@ impl FTReachablePropagator { if is_p { let distribution = g.get_variable_distribution(variable).unwrap(); - if C { + if C == COMPILED_PROPAGATION { + self.assignments.push((distribution, variable, value)); + } + if C == MIXED_PROPAGATION && !value { self.assignments.push((distribution, variable, value)); } if value { - if !C { + if C == SEARCH_PROPAGATION || C == MIXED_PROPAGATION { // If the solver is in search-mode, then we can return as soon as the computed probability is 0. // But in compilation mode, we can not know in advance if the compiled circuit will be used in a // learning framework in which the probabilities might change. diff --git a/src/search/approximate.rs b/src/search/approximate.rs index 3229e4b..f8b6efe 100644 --- a/src/search/approximate.rs +++ b/src/search/approximate.rs @@ -30,7 +30,7 @@ use search_trail::{StateManager, SaveAndRestore}; use crate::core::components::{ComponentExtractor, ComponentIndex}; use crate::core::graph::*; use crate::heuristics::branching::BranchingDecision; -use crate::propagator::FTReachablePropagator; +use crate::propagator::MixedPropagator; use crate::search::statistics::Statistics; use crate::common::*; @@ -63,7 +63,7 @@ where /// Heuristics that decide on which distribution to branch next branching_heuristic: &'b mut B, /// The propagator - propagator: FTReachablePropagator, + propagator: MixedPropagator, /// Cache used to store results of sub-problems cache: FxHashMap, /// Statistics collectors @@ -85,7 +85,7 @@ where state: StateManager, component_extractor: ComponentExtractor, branching_heuristic: &'b mut B, - propagator: FTReachablePropagator, + propagator: MixedPropagator, mlimit: u64, epsilon: f64, ) -> Self { @@ -153,44 +153,36 @@ where p_out += v_weight; }, Ok(_) => { - let mut added_proba = f128!(1.0); - let mut removed_proba = f128!(1.0); - let mut has_removed = false; - self.distribution_out_vec.fill(0.0); - for d in self.propagator.unconstrained_distributions_iter() { - let mut p_unconstrained = 0.0; - if self.graph.distribution_number_false(d, &self.state) != 0 { - for variable in self.graph.distribution_variable_iter(d) { - if !self.graph.is_variable_fixed(variable, &self.state) { - p_unconstrained += self.graph.get_variable_weight(variable).unwrap(); - } + let v = self.propagator.get_propagation_prob().clone(); + if v != 0.0 { + let mut removed_proba = f128!(1.0); + let mut has_removed = false; + self.distribution_out_vec.fill(0.0); + for (d, variable, value) in self.propagator.assignments_iter().filter(|a| a.0 != distribution) { + let weight = self.graph.get_variable_weight(variable).unwrap(); + if !value { + has_removed = true; + self.distribution_out_vec[d.0] += weight; } - added_proba *= p_unconstrained; } - } - for (d, variable, value) in self.propagator.assignments_iter().filter(|a| a.0 != distribution) { - let weight = self.graph.get_variable_weight(variable).unwrap(); - if value { - added_proba *= weight; - } else { - has_removed = true; - self.distribution_out_vec[d.0] += weight; + if has_removed { + for v in self.distribution_out_vec.iter().copied() { + removed_proba *= 1.0 - v; + } } - } - if has_removed { - for v in self.distribution_out_vec.iter().copied() { - removed_proba *= 1.0 - v; + + let child_sol = self._solve(component); + p_in += child_sol.0 * &v; + p_out += child_sol.1 * &v + v_weight * (1.0 - removed_proba.clone()); + p += child_sol.2 * &v; + if let Some(proba) = self.approximate_count(p_in.clone(), p_out.clone()) { + self.state.restore_state(); + return (p_in, p_out, proba); } + } else { + p_out += v_weight; } - let child_sol = self._solve(component); - p_in += child_sol.0 * v_weight * &added_proba; - p_out += child_sol.1 * v_weight * &added_proba + v_weight * (1.0 - removed_proba.clone()); - p += child_sol.2 * v_weight * &added_proba; - if let Some(proba) = self.approximate_count(p_in.clone(), p_out.clone()) { - self.state.restore_state(); - return (p_in, p_out, proba); - } } }; self.state.restore_state(); diff --git a/src/search/sequential.rs b/src/search/sequential.rs index c220725..9b3b2ce 100644 --- a/src/search/sequential.rs +++ b/src/search/sequential.rs @@ -30,7 +30,7 @@ use search_trail::{StateManager, SaveAndRestore}; use crate::core::components::{ComponentExtractor, ComponentIndex}; use crate::core::graph::*; use crate::heuristics::branching::BranchingDecision; -use crate::propagator::FTReachablePropagator; +use crate::propagator::SearchPropagator; use crate::search::statistics::Statistics; use crate::common::*; @@ -58,7 +58,7 @@ where /// Heuristics that decide on which distribution to branch next branching_heuristic: &'b mut B, /// The propagator - propagator: FTReachablePropagator, + propagator: SearchPropagator, /// Cache used to store results of sub-problems cache: FxHashMap, /// Statistics collectors @@ -76,7 +76,7 @@ where state: StateManager, component_extractor: ComponentExtractor, branching_heuristic: &'b mut B, - propagator: FTReachablePropagator, + propagator: SearchPropagator, mlimit: u64, ) -> Self { let cache = FxHashMap::default(); diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 3f39021..24ce1cf 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -2,7 +2,7 @@ use rug::Float; use schlandals; use schlandals::branching::*; -use schlandals::propagator::FTReachablePropagator; +use propagator::{SearchPropagator, CompiledPropagator, MixedPropagator}; use schlandals::components::*; use schlandals::*; use schlandals::search::ExactQuietSolver; @@ -23,7 +23,7 @@ macro_rules! test_input_with_branching { fn []() { let filename = format!("tests/instances/{}/{}.cnf", stringify!($dir), stringify!($name)); let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = SearchPropagator::new(); let path = PathBuf::from(filename); let graph = graph_from_ppidimacs(&path, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); @@ -38,7 +38,7 @@ macro_rules! test_input_with_branching { fn []() { let filename = format!("tests/instances/{}/{}.cnf", stringify!($dir), stringify!($name)); let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = CompiledPropagator::new(); let path = PathBuf::from(filename); let graph = graph_from_ppidimacs(&path, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); @@ -55,7 +55,7 @@ macro_rules! test_input_with_branching { fn []() { let filename = format!("tests/instances/{}/{}.cnf", stringify!($dir), stringify!($name)); let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = CompiledPropagator::new(); let path = PathBuf::from(filename); let graph = graph_from_ppidimacs(&path, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); @@ -87,7 +87,7 @@ macro_rules! test_approximate_input_with_branching { let epsilon = 0.0; let filename = format!("tests/instances/{}/{}.cnf", stringify!($dir), stringify!($name)); let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = MixedPropagator::new(); let path = PathBuf::from(filename); let graph = graph_from_ppidimacs(&path, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); @@ -103,7 +103,7 @@ macro_rules! test_approximate_input_with_branching { let epsilon = 0.05; let filename = format!("tests/instances/{}/{}.cnf", stringify!($dir), stringify!($name)); let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = MixedPropagator::new(); let path = PathBuf::from(filename); let graph = graph_from_ppidimacs(&path, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); @@ -119,7 +119,7 @@ macro_rules! test_approximate_input_with_branching { let epsilon = 0.2; let filename = format!("tests/instances/{}/{}.cnf", stringify!($dir), stringify!($name)); let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = MixedPropagator::new(); let path = PathBuf::from(filename); let graph = graph_from_ppidimacs(&path, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); @@ -135,7 +135,7 @@ macro_rules! test_approximate_input_with_branching { let epsilon = 0.5; let filename = format!("tests/instances/{}/{}.cnf", stringify!($dir), stringify!($name)); let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = MixedPropagator::new(); let path = PathBuf::from(filename); let graph = graph_from_ppidimacs(&path, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state); @@ -151,7 +151,7 @@ macro_rules! test_approximate_input_with_branching { let epsilon = 1.0; let filename = format!("tests/instances/{}/{}.cnf", stringify!($dir), stringify!($name)); let mut state = StateManager::default(); - let mut propagator = FTReachablePropagator::::new(); + let mut propagator = MixedPropagator::new(); let path = PathBuf::from(filename); let graph = graph_from_ppidimacs(&path, &mut state, &mut propagator); let component_extractor = ComponentExtractor::new(&graph, &mut state);