Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved BayesTree pruning #1293

Merged
merged 23 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c4d3889
prune hybrid gaussian ISAM more efficiently and without the need to s…
varunagrawal Sep 15, 2022
aef1669
Add labelformatter to Assignment for convenience
varunagrawal Sep 15, 2022
8b5586f
move prune method to HybridBayesTree class
varunagrawal Sep 16, 2022
dcad55c
optional maxNrLeaves for HybridGaussianISAM
varunagrawal Sep 16, 2022
5765926
optional ordering argument for HybridNonlinearISAM::update
varunagrawal Sep 16, 2022
a96c3db
minor fix
varunagrawal Sep 16, 2022
93528c3
Only eliminate variables that are in newFactors
varunagrawal Sep 16, 2022
aebcde9
add push_back to HybridBayesNet
varunagrawal Sep 16, 2022
9ef5c18
move renamed allDiscreteKeys and allContinuousKeys to HybridFactorGraph
varunagrawal Sep 17, 2022
12db5dd
undo changes
varunagrawal Sep 17, 2022
2f8a0f8
rename testHybridIncremental to testHybridGaussianISAM
varunagrawal Sep 19, 2022
c2ca426
rename allDiscreteKeys and allContinuousKeys to discreteKeys and cont…
varunagrawal Sep 20, 2022
2c4529f
Merge branch 'hybrid/improvements' into hybrid/improved-prune
varunagrawal Oct 3, 2022
5dfaa89
Merge branch 'hybrid/improved-prune' into hybrid/check-elimination
varunagrawal Oct 3, 2022
3407f97
Merge pull request #1294 from borglab/hybrid/check-elimination
varunagrawal Oct 3, 2022
ad32875
improved hybrid bayes net pruning
varunagrawal Oct 3, 2022
d6d44fc
minor cleanup
varunagrawal Oct 3, 2022
cae787a
Merge pull request #1300 from borglab/hybrid/improved-prune-2
varunagrawal Oct 4, 2022
bc8c77c
rename test file to correct form
varunagrawal Oct 4, 2022
8820bf2
Add test to expose bug in elimination with gaussian conditionals
varunagrawal Oct 4, 2022
9002b68
fix the bug
varunagrawal Oct 4, 2022
6238a1f
more docs for Switching example
varunagrawal Oct 4, 2022
fc9fc72
Merge pull request #1301 from borglab/hybrid/gaussian-conditional
varunagrawal Oct 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions gtsam/discrete/Assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@

/**
* @file Assignment.h
* @brief An assignment from labels to a discrete value index (size_t)
* @brief An assignment from labels to a discrete value index (size_t)
* @author Frank Dellaert
* @date Feb 5, 2012
*/

#pragma once

#include <functional>
#include <iostream>
#include <map>
#include <sstream>
#include <utility>
#include <vector>

Expand All @@ -32,13 +34,30 @@ namespace gtsam {
*/
template <class L>
class Assignment : public std::map<L, size_t> {
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when
* printing.
*
* @param x The value passed to format.
* @return std::string
*/
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}

public:
using std::map<L, size_t>::operator=;

void print(const std::string& s = "Assignment: ") const {
void print(const std::string& s = "Assignment: ",
const std::function<std::string(L)>& labelFormatter =
&DefaultFormatter) const {
std::cout << s << ": ";
for (const typename Assignment::value_type& keyValue : *this)
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
for (const typename Assignment::value_type& keyValue : *this) {
std::cout << "(" << labelFormatter(keyValue.first) << ", "
<< keyValue.second << ")";
}
std::cout << std::endl;
}

Expand Down
9 changes: 5 additions & 4 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ void GaussianMixture::print(const std::string &s,
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (gf && !gf->empty())
if (gf && !gf->empty()) {
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
return rd.str();
} else {
return "nullptr";
}
});
}

Expand Down
28 changes: 26 additions & 2 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,32 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;

// The canonical decision tree factor which will get the discrete conditionals
// added to it.
DecisionTreeFactor dtFactor;

for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
dtFactor = dtFactor * f;
}
}
return boost::make_shared<DecisionTreeFactor>(dtFactor);
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr discreteFactor =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));

/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
Expand Down
16 changes: 13 additions & 3 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
}

using Base::push_back;

/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atMixture(size_t i) const;

Expand Down Expand Up @@ -109,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;

/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
protected:
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;

public:
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const;

/// @}

Expand Down
61 changes: 59 additions & 2 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ struct HybridAssignmentData {
gaussianbayesTree_(gbt) {}

/**
* @brief A function used during tree traversal that operators on each node
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The HybridAssignmentData from the parent node.
* @return HybridAssignmentData
* @return HybridAssignmentData which is passed to the children.
*/
static HybridAssignmentData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& node,
Expand Down Expand Up @@ -144,4 +144,61 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
return result;
}

/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());

DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;

/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteFactor;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteFactor(prunedDiscreteFactor) {}

/**
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The data from the parent node.
* @return HybridPrunerData which is passed to the children.
*/
static HybridPrunerData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& clique,
HybridPrunerData& parentData) {
// Get the conditional
HybridConditional::shared_ptr conditional = clique->conditional();

// If conditional is hybrid, we prune it.
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();

// Check if the number of discrete keys match,
// else we get an assignment error.
// TODO(Varun) Update prune method to handle assignment subset?
if (gaussianMixture->discreteKeys() ==
parentData.prunedDiscreteFactor.discreteKeys()) {
gaussianMixture->prune(parentData.prunedDiscreteFactor);
}
}
return parentData;
}
};

HybridPrunerData rootData(prunedDiscreteFactor, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
visitorPost);
}
}

} // namespace gtsam
7 changes: 7 additions & 0 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
*/
VectorValues optimize(const DiscreteValues& assignment) const;

/**
* @brief Prune the underlying Bayes tree.
*
* @param maxNumberLeaves The max number of leaf nodes to keep.
*/
void prune(const size_t maxNumberLeaves);

/// @}

private:
Expand Down
2 changes: 0 additions & 2 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@

namespace gtsam {

class HybridGaussianFactorGraph;

/**
* Hybrid Conditional Density
*
Expand Down
22 changes: 22 additions & 0 deletions gtsam/hybrid/HybridFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
push_hybrid(p);
}
}

/// Get all the discrete keys in the factor graph.
const KeySet discreteKeys() const {
KeySet discrete_keys;
for (auto& factor : factors_) {
for (const DiscreteKey& k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}

/// Get all the continuous keys in the factor graph.
const KeySet continuousKeys() const {
KeySet keys;
for (auto& factor : factors_) {
for (const Key& key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}
};

} // namespace gtsam
32 changes: 7 additions & 25 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals(
}

} else if (f->isContinuous()) {
deferredFactors.push_back(
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
deferredFactors.push_back(gf->inner());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
deferredFactors.push_back(cg->asGaussian());
}

} else if (f->isDiscrete()) {
// Don't do anything for discrete-only factors
Expand Down Expand Up @@ -404,31 +408,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
KeySet discrete_keys;
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}

/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
KeySet keys;
for (auto &factor : factors_) {
for (const Key &key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = getDiscreteKeys();
KeySet discrete_keys = discreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
Expand Down
6 changes: 0 additions & 6 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
}
}

/// Get all the discrete keys in the factor graph.
const KeySet getDiscreteKeys() const;

/// Get all the continuous keys in the factor graph.
const KeySet getContinuousKeys() const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
Expand Down
Loading