From 5de977cc203c5efe4fb3b401723d87436c6fe1b0 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 23 Dec 2022 19:48:23 +0000 Subject: [PATCH] Implement is_root Fixes #2620 --- python/CHANGELOG.rst | 3 +++ python/tests/test_highlevel.py | 15 +++++++++++++++ python/tskit/trees.py | 17 +++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 268680a378..16e2ae4d4c 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,6 +4,9 @@ **Features** +- A new ``Tree.is_root`` method avoids the need to to search the potentially + large list of ``Tree.roots`` (:user:`hyanwong`, :pr:`2669`, :issue:`2620`) + - The ``TreeSequence`` object now has the attributes ``min_time`` and ``max_time``, which are the minimum and maximum among the node times and mutation times, respectively. (:user:`szhan`, :pr:`2612`, :issue:`2271`) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index b124ea6fca..9d169abeed 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -3867,6 +3867,21 @@ def test_branch_length_empty_tree(self): assert tree.branch_length(1) == 0 assert tree.total_branch_length == 0 + @pytest.mark.parametrize("root_threshold", [1, 2, 3]) + def test_is_root(self, root_threshold): + # Make a tree with multiple roots with different numbers of samples under each + ts = tskit.Tree.generate_balanced(5).tree_sequence + ts = ts.decapitate(ts.max_root_time - 0.1) + tables = ts.dump_tables() + tables.nodes.add_row(flags=0) # Isolated non-sample + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE) # Isolated sample + ts = tables.tree_sequence() + assert {ts.first().num_samples(u) for u in ts.first().roots} == {1, 2, 3} + tree = ts.first(root_threshold=root_threshold) + roots = set(tree.roots) + for u in range(ts.num_nodes): # Will also test isolated nodes + assert tree.is_root(u) == (u in roots) + def test_is_descendant(self): def is_descendant(tree, u, v): path = [] diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 23cf681b4b..becdf504a4 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1574,6 +1574,23 @@ def root(self): raise ValueError("More than one root exists. Use tree.roots instead") return self.left_root + def is_root(self, u) -> bool: + """ + Returns ``True`` if the specified node is a root in this tree (see + :attr:`~Tree.roots` for the definition of a root). This is exactly equivalent to + finding the node ID in :attr:`~Tree.roots`, but is more efficient for trees + with large numbers of roots, such as in regions with extensive + :ref:`sec_data_model_missing_data`. Note that ``False`` is returned for all + other nodes, including :ref:`isolated` + non-sample nodes which are not found in the topology of the current tree. + + :param int u: The node of interest. + :return: ``True`` if u is a root. + """ + return ( + self.num_samples(u) >= self.root_threshold and self.parent(u) == tskit.NULL + ) + def get_index(self): # Deprecated alias for self.index return self.index