Skip to content

Commit

Permalink
Merge pull request #1574 from borglab/feature/improved_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Jul 16, 2023
2 parents eca51af + 2453c37 commit 13c7daf
Show file tree
Hide file tree
Showing 23 changed files with 700 additions and 193 deletions.
64 changes: 56 additions & 8 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
namespace gtsam {

/**
* Algebraic Decision Trees fix the range to double
* Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether?
* An algebraic decision tree fixes the range of a DecisionTree to double.
* Just has some nice constructors and some syntactic sugar.
* TODO(dellaert): consider eliminating this class altogether?
*
* @ingroup discrete
*/
Expand Down Expand Up @@ -80,20 +80,62 @@ namespace gtsam {
AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {}

/** Create a new leaf function splitting on a variable */
/**
* @brief Create a new leaf function splitting on a variable
*
* @param labelC: The label with cardinality 2
* @param y1: The value for the first key
* @param y2: The value for the second key
*
* Example:
* @code{.cpp}
* std::pair<string, size_t> A {"a", 2};
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
* @endcode
*/
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2)
: Base(labelC, y1, y2) {}

/** Create from keys and vector table */
/**
* @brief Create from keys with cardinalities and a vector table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param ys: The vector table
*
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
* respectively, and a vector table of size 12:
* @code{.cpp}
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
* const vector<double> cpt{
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
* @endcode
* The table is given in the following order:
* A=0, B=0, C=0
* A=0, B=0, C=1
* ...
* A=1, B=1, C=1
* Hence, the first line in the table is for A==0, and the second for A==1.
* In each line, the first two entries are for B==0, the next two for B==1,
* and the last two for B==2. Each pair is for a C value of 0 and 1.
*/
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) {
const std::vector<double>& ys) {
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}

/** Create from keys and string table */
/**
* @brief Create from keys and string table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param table: The string table, given as a string of doubles.
*
* @note Table needs to be in same order as the vector table in the other constructor.
*/
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
Expand All @@ -108,7 +150,13 @@ namespace gtsam {
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}

/** Create a new function splitting on a variable */
/**
* @brief Create a range of decision trees, splitting on a single variable.
*
* @param begin: Iterator to beginning of a range of decision trees
* @param end: Iterator to end of a range of decision trees
* @param label: The label to split on
*/
template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) {
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ namespace gtsam {
// B=1
// A=0: 3
// A=1: 4
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
// Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
// exactly the same tree as above: the highest label is always the root.
// However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y>
Expand Down
31 changes: 26 additions & 5 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,23 @@
namespace gtsam {

/**
* Decision Tree
* L = label for variables
* Y = function range (any algebra), e.g., bool, int, double
* @brief a decision tree is a function from assignments to values.
* @tparam L label for variables
* @tparam Y function range (any algebra), e.g., bool, int, double
*
* After creating a decision tree on some variables, the tree can be evaluated
* on an assignment to those variables. Example:
*
* @code{.cpp}
* // Create a decision stump one one variable 'a' with values 10 and 20.
* DecisionTree<char, int> tree('a', 10, 20);
*
* // Evaluate the tree on an assignment to the variable.
* int value0 = tree({{'a', 0}}); // value0 = 10
* int value1 = tree({{'a', 1}}); // value1 = 20
* @endcode
*
* More examples can be found in testDecisionTree.cpp
*
* @ingroup discrete
*/
Expand Down Expand Up @@ -132,7 +146,8 @@ namespace gtsam {
NodePtr root_;

protected:
/** Internal recursive function to create from keys, cardinalities,
/**
* Internal recursive function to create from keys, cardinalities,
* and Y values
*/
template<typename It, typename ValueIt>
Expand Down Expand Up @@ -163,7 +178,13 @@ namespace gtsam {
/** Create a constant */
explicit DecisionTree(const Y& y);

/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
/**
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
*
* @param label The variable to split on.
* @param y1 The value for the first assignment.
* @param y2 The value for the second assignment.
*/
DecisionTree(const L& label, const Y& y1, const Y& y2);

/** Allow Label+Cardinality for convenience */
Expand Down
41 changes: 38 additions & 3 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,46 @@ namespace gtsam {
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);

/** Constructor from doubles */
/**
* @brief Constructor from doubles
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
* DecisionTreeFactor f1({X, Y}, table);
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
const std::vector<double>& table);

/** Constructor from string */
/**
* @brief Constructor from string
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);

/// Single-key specialization
Expand Down
5 changes: 5 additions & 0 deletions gtsam/discrete/DiscreteBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique

//** evaluate conditional probability of subtree for given DiscreteValues */
double evaluate(const DiscreteValues& values) const;

//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}
};

/* ************************************************************************* */
Expand Down
39 changes: 25 additions & 14 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,30 @@ class DiscreteJunctionTree;

/**
* @brief Main elimination function for DiscreteFactorGraph.
*
* @param factors
* @param keys
* @return GTSAM_EXPORT
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting conditional and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);

/**
* @brief Alternate elimination function for that creates non-normalized lookup tables.
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting lookup table and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);

/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph>
{
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
Expand All @@ -61,12 +75,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree

/// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys);
}

/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
Expand All @@ -75,7 +91,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
}
};

/* ************************************************************************* */
/**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
Expand Down Expand Up @@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph

/** Implicit copy/downcast constructor to override explicit template container
* constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
template <class DERIVED_FACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}

/// Destructor
virtual ~DiscreteFactorGraph() {}
Expand Down Expand Up @@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @}
}; // \ DiscreteFactorGraph

std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);

/// traits
template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
Expand Down
2 changes: 2 additions & 0 deletions gtsam/discrete/DiscreteJunctionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,6 @@ namespace gtsam {
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
};

/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
}
5 changes: 5 additions & 0 deletions gtsam/discrete/DiscreteValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
/// @}
};

/// Free version of CartesianProduct.
inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
return DiscreteValues::CartesianProduct(keys);
}

/// Free version of markdown.
std::string markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
Expand Down
Loading

0 comments on commit 13c7daf

Please sign in to comment.