Skip to content

Commit

Permalink
Merge pull request #1309 from borglab/varun/pruning-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Nov 3, 2022
2 parents 719e4f9 + 7f8bd8a commit 256c664
Show file tree
Hide file tree
Showing 15 changed files with 798 additions and 320 deletions.
1 change: 0 additions & 1 deletion gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ void GaussianMixture::print(const std::string &s,
}

/* ************************************************************************* */
/// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
Expand Down
3 changes: 3 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ class GTSAM_EXPORT GaussianMixture
Sum add(const Sum &sum) const;
};

/// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);

// traits
template <>
struct traits<GaussianMixture> : public Testable<GaussianMixture> {};
Expand Down
95 changes: 91 additions & 4 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,100 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());

auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) {
return 0.0;
} else {
return probability;
}
} else {
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
conditionalKeySet.begin(), conditionalKeySet.end(),
std::back_inserter(set_diff));

const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());

// If any one of the sub-branches are non-zero,
// we need this probability.
if (decisionTree(augmented_values) > 0.0) {
return probability;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return 0.0;
}
};
return pruner;
}

/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys();

for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl;
auto discrete = conditional->asDiscreteConditional();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());

// Apply prunerFunc to the underlying AlgebraicDecisionTree
auto discreteTree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));

// Create the new (hybrid) conditional
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);

// Add it back to the BayesNet
this->at(i) = conditional;
}
}
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr discreteFactor =
const DecisionTreeFactor::shared_ptr decisionTree =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));

this->updateDiscreteConditionals(decisionTree);

/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
Expand All @@ -59,7 +146,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
HybridBayesNet prunedBayesNetFragment;

// Go through all the conditionals in the
// Bayes Net and prune them as per discreteFactor.
// Bayes Net and prune them as per decisionTree.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);

Expand All @@ -69,7 +156,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Make a copy of the gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*discreteFactor);
prunedGaussianMixture->prune(*decisionTree);

// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
Expand Down
11 changes: 9 additions & 2 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;

protected:
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
Expand All @@ -123,11 +122,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {

public:
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const;
HybridBayesNet prune(size_t maxNrLeaves);

/// @}

private:
/**
* @brief Update the discrete conditionals with the pruned versions.
*
* @param prunedDecisionTree
*/
void updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {

/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());
auto decisionTree =
this->roots_.at(0)->conditional()->asDiscreteConditional();

DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
Expand Down
112 changes: 112 additions & 0 deletions gtsam/hybrid/HybridSmoother.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file HybridSmoother.cpp
* @brief An incremental smoother for hybrid factor graphs
* @author Varun Agrawal
* @date October 2022
*/

#include <gtsam/hybrid/HybridSmoother.h>

#include <algorithm>
#include <unordered_set>

namespace gtsam {

/* ************************************************************************* */
void HybridSmoother::update(HybridGaussianFactorGraph graph,
const Ordering &ordering,
boost::optional<size_t> maxNrLeaves) {
// Add the necessary conditionals from the previous timestep(s).
std::tie(graph, hybridBayesNet_) =
addConditionals(graph, hybridBayesNet_, ordering);

// Eliminate.
auto bayesNetFragment = graph.eliminateSequential(ordering);

/// Prune
if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment.
HybridBayesNet prunedBayesNetFragment =
bayesNetFragment->prune(*maxNrLeaves);
// Set the bayes net fragment to the pruned version
bayesNetFragment =
boost::make_shared<HybridBayesNet>(prunedBayesNetFragment);
}

// Add the partial bayes net to the posterior bayes net.
hybridBayesNet_.push_back<HybridBayesNet>(*bayesNetFragment);
}

/* ************************************************************************* */
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
const HybridBayesNet &originalHybridBayesNet,
const Ordering &ordering) const {
HybridGaussianFactorGraph graph(originalGraph);
HybridBayesNet hybridBayesNet(originalHybridBayesNet);

// If we are not at the first iteration, means we have conditionals to add.
if (!hybridBayesNet.empty()) {
// We add all relevant conditional mixtures on the last continuous variable
// in the previous `hybridBayesNet` to the graph

// Conditionals to remove from the bayes net
// since the conditional will be updated.
std::vector<HybridConditional::shared_ptr> conditionals_to_erase;

// New conditionals to add to the graph
gtsam::HybridBayesNet newConditionals;

// NOTE(Varun) Using a for-range loop doesn't work since some of the
// conditionals are invalid pointers
for (size_t i = 0; i < hybridBayesNet.size(); i++) {
auto conditional = hybridBayesNet.at(i);

for (auto &key : conditional->frontals()) {
if (std::find(ordering.begin(), ordering.end(), key) !=
ordering.end()) {
newConditionals.push_back(conditional);
conditionals_to_erase.push_back(conditional);

break;
}
}
}
// Remove conditionals at the end so we don't affect the order in the
// original bayes net.
for (auto &&conditional : conditionals_to_erase) {
auto it = find(hybridBayesNet.begin(), hybridBayesNet.end(), conditional);
hybridBayesNet.erase(it);
}

graph.push_back(newConditionals);
// newConditionals.print("\n\n\nNew Conditionals to add back");
}
return {graph, hybridBayesNet};
}

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
size_t index) const {
return boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet_.at(index));
}

/* ************************************************************************* */
const HybridBayesNet &HybridSmoother::hybridBayesNet() const {
return hybridBayesNet_;
}

} // namespace gtsam
73 changes: 73 additions & 0 deletions gtsam/hybrid/HybridSmoother.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file HybridSmoother.h
* @brief An incremental smoother for hybrid factor graphs
* @author Varun Agrawal
* @date October 2022
*/

#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>

namespace gtsam {

class HybridSmoother {
private:
HybridBayesNet hybridBayesNet_;
HybridGaussianFactorGraph remainingFactorGraph_;

public:
/**
* Given new factors, perform an incremental update.
* The relevant densities in the `hybridBayesNet` will be added to the input
* graph (fragment), and then eliminated according to the `ordering`
* presented. The remaining factor graph contains Gaussian mixture factors
* that are not connected to the variables in the ordering, or a single
* discrete factor on all discrete keys, plus all discrete factors in the
* original graph.
*
* \note If maxComponents is given, we look at the discrete factor resulting
* from this elimination, and prune it and the Gaussian components
* corresponding to the pruned choices.
*
* @param graph The new factors, should be linear only
* @param ordering The ordering for elimination, only continuous vars are
* allowed
* @param maxNrLeaves The maximum number of leaves in the new discrete factor,
* if applicable
*/
void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
boost::optional<size_t> maxNrLeaves = boost::none);

/**
* @brief Add conditionals from previous timestep as part of liquefication.
*
* @param graph The new factor graph for the current time step.
* @param hybridBayesNet The hybrid bayes net containing all conditionals so
* far.
* @param ordering The elimination ordering.
* @return std::pair<HybridGaussianFactorGraph, HybridBayesNet>
*/
std::pair<HybridGaussianFactorGraph, HybridBayesNet> addConditionals(
const HybridGaussianFactorGraph& graph,
const HybridBayesNet& hybridBayesNet, const Ordering& ordering) const;

/// Get the Gaussian Mixture from the Bayes Net posterior at `index`.
GaussianMixture::shared_ptr gaussianMixture(size_t index) const;

/// Return the Bayes Net posterior.
const HybridBayesNet& hybridBayesNet() const;
};

}; // namespace gtsam
Loading

0 comments on commit 256c664

Please sign in to comment.