Skip to content

Commit

Permalink
Better handling of propagator type + propagation cutoff in approximat…
Browse files Browse the repository at this point in the history
…e search
  • Loading branch information
AlexandreDubray committed Aug 7, 2023
1 parent 07be288 commit 5986fe4
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 72 deletions.
6 changes: 3 additions & 3 deletions src/compiler/exact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use search_trail::{StateManager, SaveAndRestore};
use crate::core::components::{ComponentExtractor, ComponentIndex};
use crate::core::graph::*;
use crate::heuristics::branching::BranchingDecision;
use crate::propagator::FTReachablePropagator;
use crate::propagator::CompiledPropagator;
use crate::common::*;
use crate::compiler::circuit::*;

Expand All @@ -50,7 +50,7 @@ where
/// Heuristics that decide on which distribution to branch next
branching_heuristic: &'b mut B,
/// The propagator
propagator: FTReachablePropagator<true>,
propagator: CompiledPropagator,
/// Cache used to store results of sub-problems
cache: FxHashMap<CacheEntry, Option<CircuitNodeIndex>>,
}
Expand All @@ -64,7 +64,7 @@ where
state: StateManager,
component_extractor: ComponentExtractor,
branching_heuristic: &'b mut B,
propagator: FTReachablePropagator<true>,
propagator: CompiledPropagator,
) -> Self {
let cache = FxHashMap::default();
Self {
Expand Down
7 changes: 4 additions & 3 deletions src/core/components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ impl ComponentExtractor {
/// This function is responsible of updating the data structure with the new connected
/// components in `g` given its current assignments.
/// Returns true iff at least one component has been detected and it contains one distribution
pub fn detect_components<const C: bool>(
pub fn detect_components<const C: u8>(
&mut self,
g: &mut Graph,
state: &mut StateManager,
Expand Down Expand Up @@ -344,6 +344,7 @@ mod test_component_detection {
use crate::core::graph::{Graph, VariableIndex, ClauseIndex};
use crate::core::components::*;
use search_trail::{StateManager, SaveAndRestore};
use crate::propagator::SearchPropagator;

// Graph used for the tests:
//
Expand Down Expand Up @@ -397,7 +398,7 @@ mod test_component_detection {
let mut state = StateManager::default();
let mut g = get_graph(&mut state);
let mut extractor = ComponentExtractor::new(&g, &mut state);
let mut propagator = FTReachablePropagator::<false>::new();
let mut propagator = SearchPropagator::new();

g.set_clause_unconstrained(ClauseIndex(4), &mut state);
extractor.detect_components(&mut g, &mut state, ComponentIndex(0), &mut propagator);
Expand All @@ -418,7 +419,7 @@ mod test_component_detection {
let mut state = StateManager::default();
let mut g = get_graph(&mut state);
let mut extractor = ComponentExtractor::new(&g, &mut state);
let mut propagator = FTReachablePropagator::<false>::new();
let mut propagator = SearchPropagator::new();

state.save_state();

Expand Down
8 changes: 4 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::core::components::ComponentExtractor;
use parser::*;
use heuristics::branching::*;
use search::{ExactDefaultSolver, ExactQuietSolver, ApproximateDefaultSolver, ApproximateQuietSolver};
use propagator::FTReachablePropagator;
use propagator::{SearchPropagator, CompiledPropagator, MixedPropagator};
use compiler::exact::ExactDACCompiler;
use compiler::circuit::Dac;

Expand Down Expand Up @@ -132,7 +132,7 @@ fn read_compiled(input: PathBuf, dotfile: Option<PathBuf>) {

fn run_compilation(input: PathBuf, branching: Branching, fdac: Option<PathBuf>, dotfile: Option<PathBuf>) {
let mut state = StateManager::default();
let mut propagator = FTReachablePropagator::<true>::new();
let mut propagator = CompiledPropagator::new();
let graph = graph_from_ppidimacs(&input, &mut state, &mut propagator);
let component_extractor = ComponentExtractor::new(&graph, &mut state);
let mut branching_heuristic: Box<dyn BranchingDecision> = match branching {
Expand Down Expand Up @@ -170,7 +170,7 @@ fn run_compilation(input: PathBuf, branching: Branching, fdac: Option<PathBuf>,

fn run_approx_search(input: PathBuf, branching: Branching, statistics: bool, memory: Option<u64>, epsilon: f64) {
let mut state = StateManager::default();
let mut propagator = FTReachablePropagator::<true>::new();
let mut propagator = MixedPropagator::new();
let graph = graph_from_ppidimacs(&input, &mut state, &mut propagator);
let component_extractor = ComponentExtractor::new(&graph, &mut state);
let mut branching_heuristic: Box<dyn BranchingDecision> = match branching {
Expand Down Expand Up @@ -212,7 +212,7 @@ fn run_approx_search(input: PathBuf, branching: Branching, statistics: bool, mem

fn run_search(input: PathBuf, branching: Branching, statistics: bool, memory: Option<u64>) {
let mut state = StateManager::default();
let mut propagator = FTReachablePropagator::<false>::new();
let mut propagator = SearchPropagator::new();
let graph = graph_from_ppidimacs(&input, &mut state, &mut propagator);
let component_extractor = ComponentExtractor::new(&graph, &mut state);
let mut branching_heuristic: Box<dyn BranchingDecision> = match branching {
Expand Down
4 changes: 2 additions & 2 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;

pub fn graph_from_ppidimacs<const C: bool>(
pub fn graph_from_ppidimacs<const C: u8>(
filepath: &PathBuf,
state: &mut StateManager,
propagator: &mut FTReachablePropagator<C>,
Expand Down Expand Up @@ -77,7 +77,7 @@ pub fn graph_from_ppidimacs<const C: bool>(
.collect::<Vec<f64>>();
let nodes = g.add_distribution(&split, state);
for i in 0..split.len() {
if !C && split[i] == 1.0 {
if propagator.is_search() && split[i] == 1.0 {
propagator.add_to_propagation_stack(nodes[i], true);
}
}
Expand Down
44 changes: 31 additions & 13 deletions src/propagator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ pub struct Unsat;

pub type PropagationResult = Result<(), Unsat>;

pub struct FTReachablePropagator<const C: bool> {
const COMPILED_PROPAGATION: u8 = 1;
const SEARCH_PROPAGATION: u8 = 2;
const MIXED_PROPAGATION: u8 = 3;

pub type CompiledPropagator = FTReachablePropagator<COMPILED_PROPAGATION>;
pub type SearchPropagator = FTReachablePropagator<SEARCH_PROPAGATION>;
pub type MixedPropagator = FTReachablePropagator<MIXED_PROPAGATION>;

pub struct FTReachablePropagator<const C: u8> {
propagation_stack: Vec<(VariableIndex, bool)>,
pub unconstrained_clauses: Vec<ClauseIndex>,
t_reachable: Vec<bool>,
Expand All @@ -67,13 +75,13 @@ pub struct FTReachablePropagator<const C: bool> {
propagation_prob: Float,
}

impl<const C: bool> Default for FTReachablePropagator<C> {
impl<const C: u8> Default for FTReachablePropagator<C> {
fn default() -> Self {
Self::new()
}
}

impl<const C: bool> FTReachablePropagator<C> {
impl<const C: u8> FTReachablePropagator<C> {

pub fn new() -> Self {
Self {
Expand All @@ -87,6 +95,10 @@ impl<const C: bool> FTReachablePropagator<C> {
}
}

pub fn is_search(&self) -> bool {
C == SEARCH_PROPAGATION || C == MIXED_PROPAGATION
}

/// Sets the number of clauses for the f-reachable and t-reachable vectors
pub fn set_number_clauses(&mut self, n: usize) {
self.t_reachable.resize(n, false);
Expand All @@ -112,7 +124,7 @@ impl<const C: bool> FTReachablePropagator<C> {
self.unconstrained_clauses.push(clause);
}
}

/// Returns the propagation probability of the last call to propagate
pub fn get_propagation_prob(&self) -> &Float {
&self.propagation_prob
Expand Down Expand Up @@ -141,14 +153,17 @@ impl<const C: bool> FTReachablePropagator<C> {
/// Computes the unconstrained probability of a distribution. When a distribution does not appear anymore in any constrained
/// clauses, the probability of branching on it can be pre-computed. This is what this function returns.
fn propagate_unconstrained_distribution(&mut self, g: &Graph, distribution: DistributionIndex, state: &StateManager) {
if C {
self.unconstrained_distributions.push(distribution);
} else if g.distribution_number_false(distribution, state) != 0 {
let mut p = f128!(0.0);
for w in g.distribution_variable_iter(distribution).filter(|v| !g.is_variable_fixed(*v, state)).map(|v| g.get_variable_weight(v).unwrap()) {
p += w;
if g.distribution_number_false(distribution, state) != 0 {
if C == COMPILED_PROPAGATION {
self.unconstrained_distributions.push(distribution);
}
if C == SEARCH_PROPAGATION || C == MIXED_PROPAGATION {
let mut p = f128!(0.0);
for w in g.distribution_variable_iter(distribution).filter(|v| !g.is_variable_fixed(*v, state)).map(|v| g.get_variable_weight(v).unwrap()) {
p += w;
}
self.propagation_prob *= &p;
}
self.propagation_prob *= &p;
}
}

Expand Down Expand Up @@ -242,11 +257,14 @@ impl<const C: bool> FTReachablePropagator<C> {

if is_p {
let distribution = g.get_variable_distribution(variable).unwrap();
if C {
if C == COMPILED_PROPAGATION {
self.assignments.push((distribution, variable, value));
}
if C == MIXED_PROPAGATION && !value {
self.assignments.push((distribution, variable, value));
}
if value {
if !C {
if C == SEARCH_PROPAGATION || C == MIXED_PROPAGATION {
// If the solver is in search-mode, then we can return as soon as the computed probability is 0.
// But in compilation mode, we can not know in advance if the compiled circuit will be used in a
// learning framework in which the probabilities might change.
Expand Down
62 changes: 27 additions & 35 deletions src/search/approximate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use search_trail::{StateManager, SaveAndRestore};
use crate::core::components::{ComponentExtractor, ComponentIndex};
use crate::core::graph::*;
use crate::heuristics::branching::BranchingDecision;
use crate::propagator::FTReachablePropagator;
use crate::propagator::MixedPropagator;
use crate::search::statistics::Statistics;
use crate::common::*;

Expand Down Expand Up @@ -63,7 +63,7 @@ where
/// Heuristics that decide on which distribution to branch next
branching_heuristic: &'b mut B,
/// The propagator
propagator: FTReachablePropagator<true>,
propagator: MixedPropagator,
/// Cache used to store results of sub-problems
cache: FxHashMap<CacheEntry, NodeSolution>,
/// Statistics collectors
Expand All @@ -85,7 +85,7 @@ where
state: StateManager,
component_extractor: ComponentExtractor,
branching_heuristic: &'b mut B,
propagator: FTReachablePropagator<true>,
propagator: MixedPropagator,
mlimit: u64,
epsilon: f64,
) -> Self {
Expand Down Expand Up @@ -153,44 +153,36 @@ where
p_out += v_weight;
},
Ok(_) => {
let mut added_proba = f128!(1.0);
let mut removed_proba = f128!(1.0);
let mut has_removed = false;
self.distribution_out_vec.fill(0.0);
for d in self.propagator.unconstrained_distributions_iter() {
let mut p_unconstrained = 0.0;
if self.graph.distribution_number_false(d, &self.state) != 0 {
for variable in self.graph.distribution_variable_iter(d) {
if !self.graph.is_variable_fixed(variable, &self.state) {
p_unconstrained += self.graph.get_variable_weight(variable).unwrap();
}
let v = self.propagator.get_propagation_prob().clone();
if v != 0.0 {
let mut removed_proba = f128!(1.0);
let mut has_removed = false;
self.distribution_out_vec.fill(0.0);
for (d, variable, value) in self.propagator.assignments_iter().filter(|a| a.0 != distribution) {
let weight = self.graph.get_variable_weight(variable).unwrap();
if !value {
has_removed = true;
self.distribution_out_vec[d.0] += weight;
}
added_proba *= p_unconstrained;
}
}
for (d, variable, value) in self.propagator.assignments_iter().filter(|a| a.0 != distribution) {
let weight = self.graph.get_variable_weight(variable).unwrap();
if value {
added_proba *= weight;
} else {
has_removed = true;
self.distribution_out_vec[d.0] += weight;
if has_removed {
for v in self.distribution_out_vec.iter().copied() {
removed_proba *= 1.0 - v;
}
}
}
if has_removed {
for v in self.distribution_out_vec.iter().copied() {
removed_proba *= 1.0 - v;

let child_sol = self._solve(component);
p_in += child_sol.0 * &v;
p_out += child_sol.1 * &v + v_weight * (1.0 - removed_proba.clone());
p += child_sol.2 * &v;
if let Some(proba) = self.approximate_count(p_in.clone(), p_out.clone()) {
self.state.restore_state();
return (p_in, p_out, proba);
}
} else {
p_out += v_weight;
}

let child_sol = self._solve(component);
p_in += child_sol.0 * v_weight * &added_proba;
p_out += child_sol.1 * v_weight * &added_proba + v_weight * (1.0 - removed_proba.clone());
p += child_sol.2 * v_weight * &added_proba;
if let Some(proba) = self.approximate_count(p_in.clone(), p_out.clone()) {
self.state.restore_state();
return (p_in, p_out, proba);
}
}
};
self.state.restore_state();
Expand Down
6 changes: 3 additions & 3 deletions src/search/sequential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use search_trail::{StateManager, SaveAndRestore};
use crate::core::components::{ComponentExtractor, ComponentIndex};
use crate::core::graph::*;
use crate::heuristics::branching::BranchingDecision;
use crate::propagator::FTReachablePropagator;
use crate::propagator::SearchPropagator;
use crate::search::statistics::Statistics;
use crate::common::*;

Expand Down Expand Up @@ -58,7 +58,7 @@ where
/// Heuristics that decide on which distribution to branch next
branching_heuristic: &'b mut B,
/// The propagator
propagator: FTReachablePropagator<false>,
propagator: SearchPropagator,
/// Cache used to store results of sub-problems
cache: FxHashMap<CacheEntry, Float>,
/// Statistics collectors
Expand All @@ -76,7 +76,7 @@ where
state: StateManager,
component_extractor: ComponentExtractor,
branching_heuristic: &'b mut B,
propagator: FTReachablePropagator<false>,
propagator: SearchPropagator,
mlimit: u64,
) -> Self {
let cache = FxHashMap::default();
Expand Down
Loading

0 comments on commit 5986fe4

Please sign in to comment.