diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index b124ea6fca..9d169abeed 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -3867,6 +3867,21 @@ def test_branch_length_empty_tree(self): assert tree.branch_length(1) == 0 assert tree.total_branch_length == 0 + @pytest.mark.parametrize("root_threshold", [1, 2, 3]) + def test_is_root(self, root_threshold): + # Make a tree with multiple roots with different numbers of samples under each + ts = tskit.Tree.generate_balanced(5).tree_sequence + ts = ts.decapitate(ts.max_root_time - 0.1) + tables = ts.dump_tables() + tables.nodes.add_row(flags=0) # Isolated non-sample + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE) # Isolated sample + ts = tables.tree_sequence() + assert {ts.first().num_samples(u) for u in ts.first().roots} == {1, 2, 3} + tree = ts.first(root_threshold=root_threshold) + roots = set(tree.roots) + for u in range(ts.num_nodes): # Will also test isolated nodes + assert tree.is_root(u) == (u in roots) + def test_is_descendant(self): def is_descendant(tree, u, v): path = [] diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 23cf681b4b..c3519cf86a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1574,6 +1574,21 @@ def root(self): raise ValueError("More than one root exists. Use tree.roots instead") return self.left_root + def is_root(self, u) -> bool: + """ + Returns ``True`` if the specified node is a root in this tree. See + :attr:`~Tree.roots` for the definition of a root. Note that this will return + ``False`` for all other nodes, including + :ref:`isolated` non-sample nodes which are + not attached to the topology of the current tree. + + :param int u: The node of interest. + :return: ``True`` if u is a root. + """ + return ( + self.num_samples(u) >= self.root_threshold and self.parent(u) == tskit.NULL + ) + def get_index(self): # Deprecated alias for self.index return self.index