From febeacd68686ed0b7ced72458eb0b31a196bdab7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 4 Jun 2023 15:40:02 +0100 Subject: [PATCH 01/15] Improved documentation and tests --- gtsam/discrete/AlgebraicDecisionTree.h | 64 +++++++++++++++++++--- gtsam/discrete/DecisionTree-inl.h | 2 +- gtsam/discrete/DecisionTree.h | 31 +++++++++-- gtsam/discrete/tests/testDecisionTree.cpp | 66 ++++++++++++++++++----- 4 files changed, 137 insertions(+), 26 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index b3f0d69b0e..cd77e41f8e 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -28,9 +28,9 @@ namespace gtsam { /** - * Algebraic Decision Trees fix the range to double - * Just has some nice constructors and some syntactic sugar - * TODO: consider eliminating this class altogether? + * An algebraic decision tree fixes the range of a DecisionTree to double. + * Just has some nice constructors and some syntactic sugar. + * TODO(dellaert): consider eliminating this class altogether? 
* * @ingroup discrete */ @@ -80,20 +80,62 @@ namespace gtsam { AlgebraicDecisionTree(const L& label, double y1, double y2) : Base(label, y1, y2) {} - /** Create a new leaf function splitting on a variable */ + /** + * @brief Create a new leaf function splitting on a variable + * + * @param labelC: The label with cardinality 2 + * @param y1: The value for the first key + * @param y2: The value for the second key + * + * Example: + * @code{.cpp} + * std::pair A {"a", 2}; + * AlgebraicDecisionTree a(A, 0.6, 0.4); + * @endcode + */ AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : Base(labelC, y1, y2) {} - /** Create from keys and vector table */ + /** + * @brief Create from keys with cardinalities and a vector table + * + * @param labelCs: The keys, with cardinalities, given as pairs + * @param ys: The vector table + * + * Example with three keys, A, B, and C, with cardinalities 2, 3, and 2, + * respectively, and a vector table of size 12: + * @code{.cpp} + * DiscreteKey A(0, 2), B(1, 3), C(2, 2); + * const vector cpt{ + * 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + * 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10}; + * AlgebraicDecisionTree expected(A & B & C, cpt); + * @endcode + * The table is given in the following order: + * A=0, B=0, C=0 + * A=0, B=0, C=1 + * ... + * A=1, B=2, C=1 + * Hence, the first line in the table is for A==0, and the second for A==1. + * In each line, the first two entries are for B==0, the next two for B==1, + * and the last two for B==2. Each pair is for a C value of 0 and 1. 
+ */ AlgebraicDecisionTree // (const std::vector& labelCs, - const std::vector& ys) { + const std::vector& ys) { this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /** Create from keys and string table */ + /** + * @brief Create from keys and string table + * + * @param labelCs: The keys, with cardinalities, given as pairs + * @param table: The string table, given as a string of doubles. + * + * @note Table needs to be in the same order as the vector table in the other constructor. + */ AlgebraicDecisionTree // (const std::vector& labelCs, const std::string& table) { @@ -108,7 +150,13 @@ namespace gtsam { Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /** Create a new function splitting on a variable */ + /** + * @brief Create a decision tree from a range of decision trees, splitting on a single variable. + * + * @param begin: Iterator to beginning of a range of decision trees + * @param end: Iterator to end of a range of decision trees + * @param label: The label to split on + */ template AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : Base(nullptr) { diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 9f3d5e8f95..4d1670bb74 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -622,7 +622,7 @@ namespace gtsam { // B=1 // A=0: 3 // A=1: 4 - // Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce + // Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce // exactly the same tree as above: the highest label is always the root. // However, it will be *way* faster if labels are given highest to lowest. 
template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index a8764a98f7..06e945cf9f 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -37,9 +37,23 @@ namespace gtsam { /** - * Decision Tree - * L = label for variables - * Y = function range (any algebra), e.g., bool, int, double + * @brief A decision tree is a function from assignments to values. + * @tparam L label for variables + * @tparam Y function range (any algebra), e.g., bool, int, double + * + * After creating a decision tree on some variables, the tree can be evaluated + * on an assignment to those variables. Example: + * + * @code{.cpp} + * // Create a decision stump on one variable 'a' with values 10 and 20. + * DecisionTree tree('a', 10, 20); + * + * // Evaluate the tree on an assignment to the variable. + * int value0 = tree({{'a', 0}}); // value0 = 10 + * int value1 = tree({{'a', 1}}); // value1 = 20 + * @endcode + * + * More examples can be found in testDecisionTree.cpp * * @ingroup discrete */ @@ -132,7 +146,8 @@ namespace gtsam { NodePtr root_; protected: - /** Internal recursive function to create from keys, cardinalities, + /** + * Internal recursive function to create from keys, cardinalities, * and Y values */ template @@ -163,7 +178,13 @@ namespace gtsam { /** Create a constant */ explicit DecisionTree(const Y& y); - /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` + /** + * @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` + * + * @param label The variable to split on. + * @param y1 The value for the first assignment. + * @param y2 The value for the second assignment. 
+ */ DecisionTree(const L& label, const Y& y1, const Y& y2); /** Allow Label+Cardinality for convenience */ diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 2f385263c1..fbcecb5abb 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -71,6 +71,19 @@ struct traits : public Testable {}; GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) +/* ************************************************************************** */ +// Test char labels and int range +/* ************************************************************************** */ + +// Create a decision stump on one variable 'a' with values 10 and 20. +TEST(DecisionTree, constructor) { + DecisionTree tree('a', 10, 20); + + // Evaluate the tree on an assignment to the variable. + EXPECT_LONGS_EQUAL(10, tree({{'a', 0}})); + EXPECT_LONGS_EQUAL(20, tree({{'a', 1}})); +} + /* ************************************************************************** */ // Test string labels and int range /* ************************************************************************** */ @@ -114,18 +127,47 @@ struct Ring { static inline int mul(const int& a, const int& b) { return a * b; } }; +/* ************************************************************************** */ +// Check that creating decision trees respects key order. +TEST(DecisionTree, constructor_order) { + // Create labels + string A("A"), B("B"); + + const std::vector ys1 = {1, 2, 3, 4}; + DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A! + + const std::vector ys2 = {1, 3, 2, 4}; + DT tree2({{A, 2}, {B, 2}}, ys2); // slower version ! + + // Both trees will be the same, tree is ordered from high to low labels. 
+ // Choice(B) + // 0 Choice(A) + // 0 0 Leaf 1 + // 0 1 Leaf 2 + // 1 Choice(A) + // 1 0 Leaf 3 + // 1 1 Leaf 4 + + EXPECT(tree2.equals(tree1)); + + // Check the values are as expected by calling the () operator: + EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}})); + EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}})); + EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}})); + EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}})); +} + /* ************************************************************************** */ // test DT TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); - // create a value - Assignment x00, x01, x10, x11; - x00[A] = 0, x00[B] = 0; - x01[A] = 0, x01[B] = 1; - x10[A] = 1, x10[B] = 0; - x11[A] = 1, x11[B] = 1; + // Create assignments using brace initialization: + Assignment x00{{A, 0}, {B, 0}}; + Assignment x01{{A, 0}, {B, 1}}; + Assignment x10{{A, 1}, {B, 0}}; + Assignment x11{{A, 1}, {B, 1}}; // empty DT empty; @@ -237,8 +279,7 @@ TEST(DecisionTree, ConvertValuesOnly) { StringBoolTree f2(f1, bool_of_int); // Check a value - Assignment x00; - x00["A"] = 0, x00["B"] = 0; + Assignment x00 {{A, 0}, {B, 0}}; EXPECT(!f2(x00)); } @@ -262,10 +303,11 @@ TEST(DecisionTree, ConvertBoth) { // Check some values Assignment