Skip to content

Commit

Permalink
Generic solver over compilation to avoid unnecessary memory usage in …
Browse files Browse the repository at this point in the history
…search
  • Loading branch information
AlexandreDubray committed Aug 20, 2024
1 parent 1b24457 commit 6d749b5
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 39 deletions.
18 changes: 11 additions & 7 deletions src/learning/learner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ impl <const S: bool> Learner<S> {
}

fn recompile_dacs(&mut self, branching: Branching, approx:ApproximateMethod, compile_timeout: u64) {
let distributions = self.get_softmaxed_array().iter().map(|d| d.iter().map(|f| f.to_f64()).collect::<Vec<f64>>()).collect();
let distributions: Vec<Vec<f64>> = self.get_softmaxed_array().iter().map(|d| d.iter().map(|f| f.to_f64()).collect::<Vec<f64>>()).collect();
let mut train_dacs = generate_dacs(&self.clauses, &distributions, branching, self.epsilon, approx, compile_timeout);
let mut train_data = vec![];
let mut eps = 0.0;
Expand Down Expand Up @@ -423,7 +423,7 @@ pub fn softmax(x: &[f64]) -> Vec<Float> {
}

/// Generates a vector of optional Dacs from a list of input files
pub fn generate_dacs<R: SemiRing>(queries_clauses: &Vec<Vec<Vec<isize>>>, distributions: &Vec<Vec<f64>>,branching: Branching, epsilon: f64, approx: ApproximateMethod, timeout: u64) -> Vec<Dac<R>> {
pub fn generate_dacs<R: SemiRing>(queries_clauses: &Vec<Vec<Vec<isize>>>, distributions: &[Vec<f64>],branching: Branching, epsilon: f64, approx: ApproximateMethod, timeout: u64) -> Vec<Dac<R>> {
queries_clauses.par_iter().map(|clauses| {
// We compile the input. This can either be a .cnf file or a fdac file.
// If the file is a fdac file, then we read directly from it
Expand All @@ -432,18 +432,22 @@ pub fn generate_dacs<R: SemiRing>(queries_clauses: &Vec<Vec<Vec<isize>>>, distri
let parameters = SolverParameters::new(u64::MAX, epsilon, timeout);
let propagator = Propagator::new(&mut state);
let component_extractor = ComponentExtractor::new(&problem, &mut state);
let compiler = generic_solver(problem, state, component_extractor, branching, propagator, parameters, false);
let compiler = generic_solver(problem, state, component_extractor, branching, propagator, parameters, false, true);
match approx {
ApproximateMethod::Bounds => {
match compiler {
crate::GenericSolver::SMinInDegree(mut s) => s.compile(false),
crate::GenericSolver::QMinInDegree(mut s) => s.compile(false),
crate::GenericSolver::SMinInDegreeCompile(mut s) => s.compile(false),
crate::GenericSolver::QMinInDegreeCompile(mut s) => s.compile(false),
crate::GenericSolver::SMinInDegreeSearch(_) => panic!("Non compile solver used for learning"),
crate::GenericSolver::QMinInDegreeSearch(_) => panic!("Non compile solver used for learning"),
}
},
ApproximateMethod::LDS => {
match compiler {
crate::GenericSolver::SMinInDegree(mut s) => s.compile(true),
crate::GenericSolver::QMinInDegree(mut s) => s.compile(true),
crate::GenericSolver::SMinInDegreeCompile(mut s) => s.compile(true),
crate::GenericSolver::QMinInDegreeCompile(mut s) => s.compile(true),
crate::GenericSolver::SMinInDegreeSearch(_) => panic!("Non compile solver used for learning"),
crate::GenericSolver::QMinInDegreeSearch(_) => panic!("Non compile solver used for learning"),
}
},

Expand Down
84 changes: 57 additions & 27 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,23 @@ pub fn search(args: Args) -> f64 {
let parser = parser_from_input(args.input.clone(), args.evidence.clone());
let problem = parser.problem_from_file(&mut state);
let component_extractor = ComponentExtractor::new(&problem, &mut state);
let solver = generic_solver(problem, state, component_extractor, args.branching, propagator, parameters, args.statistics);
let solver = generic_solver(problem, state, component_extractor, args.branching, propagator, parameters, args.statistics, false);

let solution = match args.approx {
ApproximateMethod::Bounds => {
match solver {
GenericSolver::SMinInDegree(mut solver) => solver.search(false),
GenericSolver::QMinInDegree(mut solver) => solver.search(false),
GenericSolver::SMinInDegreeSearch(mut solver) => solver.search(false),
GenericSolver::QMinInDegreeSearch(mut solver) => solver.search(false),
GenericSolver::SMinInDegreeCompile(_) => panic!("Non search solver used in search"),
GenericSolver::QMinInDegreeCompile(_) => panic!("Non search solver used in search"),
}
},
ApproximateMethod::LDS => {
match solver {
GenericSolver::SMinInDegree(mut solver) => solver.search(true),
GenericSolver::QMinInDegree(mut solver) => solver.search(true),
GenericSolver::SMinInDegreeSearch(mut solver) => solver.search(true),
GenericSolver::QMinInDegreeSearch(mut solver) => solver.search(true),
GenericSolver::SMinInDegreeCompile(_) => panic!("Non search solver used in search"),
GenericSolver::QMinInDegreeCompile(_) => panic!("Non search solver used in search"),
}
},
};
Expand All @@ -242,10 +246,12 @@ pub fn pysearch(args: Args, distributions: &[Vec<f64>], clauses: &[Vec<isize>])
let propagator = Propagator::new(&mut state);
let problem = create_problem(distributions, clauses, &mut state);
let component_extractor = ComponentExtractor::new(&problem, &mut state);
let solver = generic_solver(problem, state, component_extractor, args.branching, propagator, parameters, args.statistics);
let solver = generic_solver(problem, state, component_extractor, args.branching, propagator, parameters, args.statistics, false);
let solution = match solver {
GenericSolver::SMinInDegree(mut solver) => solver.search(false),
GenericSolver::QMinInDegree(mut solver) => solver.search(false),
GenericSolver::SMinInDegreeSearch(mut solver) => solver.search(false),
GenericSolver::QMinInDegreeSearch(mut solver) => solver.search(false),
GenericSolver::SMinInDegreeCompile(_) => panic!("Non search solver used in search"),
GenericSolver::QMinInDegreeCompile(_) => panic!("Non search solver used in search"),
};
solution.print();
solution.bounds()
Expand All @@ -258,19 +264,23 @@ pub fn compile(args: Args) -> f64 {
let parser = parser_from_input(args.input.clone(), args.evidence.clone());
let problem = parser.problem_from_file(&mut state);
let component_extractor = ComponentExtractor::new(&problem, &mut state);
let solver = generic_solver(problem, state, component_extractor, args.branching, propagator, parameters, args.statistics);
let solver = generic_solver(problem, state, component_extractor, args.branching, propagator, parameters, args.statistics, true);

let mut ac: Dac<Float> = match args.approx {
ApproximateMethod::Bounds => {
match solver {
GenericSolver::SMinInDegree(mut solver) => solver.compile(false),
GenericSolver::QMinInDegree(mut solver) => solver.compile(false),
GenericSolver::SMinInDegreeCompile(mut solver) => solver.compile(false),
GenericSolver::QMinInDegreeCompile(mut solver) => solver.compile(false),
GenericSolver::SMinInDegreeSearch(_) => panic!("Non compile solver used in compilation"),
GenericSolver::QMinInDegreeSearch(_) => panic!("Non compile solver used in compilation"),
}
},
ApproximateMethod::LDS => {
match solver {
GenericSolver::SMinInDegree(mut solver) => solver.compile(true),
GenericSolver::QMinInDegree(mut solver) => solver.compile(true),
GenericSolver::SMinInDegreeCompile(mut solver) => solver.compile(true),
GenericSolver::QMinInDegreeCompile(mut solver) => solver.compile(true),
GenericSolver::SMinInDegreeSearch(_) => panic!("Non compile solver used in compilation"),
GenericSolver::QMinInDegreeSearch(_) => panic!("Non compile solver used in compilation"),
}
},
};
Expand Down Expand Up @@ -367,24 +377,44 @@ impl std::fmt::Display for Loss {
}

pub enum GenericSolver {
SMinInDegree(Solver<MinInDegree, true>),
QMinInDegree(Solver<MinInDegree, false>),
SMinInDegreeSearch(Solver<MinInDegree, true, false>),
QMinInDegreeSearch(Solver<MinInDegree, false, false>),
SMinInDegreeCompile(Solver<MinInDegree, true, true>),
QMinInDegreeCompile(Solver<MinInDegree, false, true>),
}

pub fn generic_solver(problem: Problem, state: StateManager, component_extractor: ComponentExtractor, branching: Branching, propagator: Propagator, parameters: SolverParameters, stat: bool) -> GenericSolver {
if stat {
match branching {
Branching::MinInDegree => {
let solver = Solver::<MinInDegree, true>::new(problem, state, component_extractor, Box::<MinInDegree>::default(), propagator, parameters);
GenericSolver::SMinInDegree(solver)
},
pub fn generic_solver(problem: Problem, state: StateManager, component_extractor: ComponentExtractor, branching: Branching, propagator: Propagator, parameters: SolverParameters, stat: bool, compile: bool) -> GenericSolver {
if compile {
if stat {
match branching {
Branching::MinInDegree => {
let solver = Solver::<MinInDegree, true, true>::new(problem, state, component_extractor, Box::<MinInDegree>::default(), propagator, parameters);
GenericSolver::SMinInDegreeCompile(solver)
},
}
} else {
match branching {
Branching::MinInDegree => {
let solver = Solver::<MinInDegree, false, true>::new(problem, state, component_extractor, Box::<MinInDegree>::default(), propagator, parameters);
GenericSolver::QMinInDegreeCompile(solver)
},
}
}
} else {
match branching {
Branching::MinInDegree => {
let solver = Solver::<MinInDegree, false>::new(problem, state, component_extractor, Box::<MinInDegree>::default(), propagator, parameters);
GenericSolver::QMinInDegree(solver)
},
if stat {
match branching {
Branching::MinInDegree => {
let solver = Solver::<MinInDegree, true, false>::new(problem, state, component_extractor, Box::<MinInDegree>::default(), propagator, parameters);
GenericSolver::SMinInDegreeSearch(solver)
},
}
} else {
match branching {
Branching::MinInDegree => {
let solver = Solver::<MinInDegree, false, false>::new(problem, state, component_extractor, Box::<MinInDegree>::default(), propagator, parameters);
GenericSolver::QMinInDegreeSearch(solver)
},
}
}
}
}
16 changes: 11 additions & 5 deletions src/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type UnconstrainedDistribution = (DistributionIndex, Vec<VariableIndex>);
/// Finally, the compiler is able to create an arithmetic circuit for any semi-ring. Currently
/// implemented are the probability semi-ring (the default) and tensor semi-ring, which uses torch
/// tensors (useful for automatic differentiation in learning).
pub struct Solver<B: BranchingDecision, const S: bool> {
pub struct Solver<B: BranchingDecision, const S: bool, const C: bool> {
/// Implication problem of the (Horn) clauses in the input
problem: Problem,
/// Manages (save/restore) the states (e.g., reversible primitive types)
Expand All @@ -69,10 +69,12 @@ pub struct Solver<B: BranchingDecision, const S: bool> {
preproc_out: Option<f64>,
/// Parameters of the solving
parameters: SolverParameters,
/// The caches present in the cache. Used during compilation to reconstruct the AC from the
/// cache (follow the children of a node)
cache_keys: Vec<CacheKey>,
}

impl<B: BranchingDecision, const S: bool> Solver<B, S> {
impl<B: BranchingDecision, const S: bool, const C: bool> Solver<B, S, C> {
pub fn new(
problem: Problem,
state: StateManager,
Expand Down Expand Up @@ -205,7 +207,9 @@ impl<B: BranchingDecision, const S: bool> Solver<B, S> {
let mut cache_entry = self.cache.remove(&cache_key).unwrap_or_else(|| {
self.statistics.cache_miss();
let cache_key_index = self.cache_keys.len();
self.cache_keys.push(cache_key.clone());
if C {
self.cache_keys.push(cache_key.clone());
}
CacheEntry::new((F128!(0.0), F128!(0.0)), 0, None, FxHashMap::default(), cache_key_index)
});
if cache_entry.distribution.is_none() {
Expand Down Expand Up @@ -282,7 +286,9 @@ impl<B: BranchingDecision, const S: bool> Solver<B, S> {
is_product_sat = false;
break;
}
child_entry.add_key(sub_solution.cache_index);
if C {
child_entry.add_key(sub_solution.cache_index);
}
}
}
if is_product_sat && prod_p_in > 0.0 {
Expand Down Expand Up @@ -310,7 +316,7 @@ impl<B: BranchingDecision, const S: bool> Solver<B, S> {
}
}

impl<B: BranchingDecision, const S: bool> Solver<B, S> {
impl<B: BranchingDecision, const S: bool, const C: bool> Solver<B, S, C> {

pub fn compile<R: SemiRing>(&mut self, is_lds: bool) -> Dac<R> {
let start = Instant::now();
Expand Down

0 comments on commit 6d749b5

Please sign in to comment.