From c7d7ce24837b96bd400d42138d5b0e73015b6dcc Mon Sep 17 00:00:00 2001 From: Ao Zhang Date: Mon, 24 Oct 2022 22:47:10 +0100 Subject: [PATCH] Add unweighted_rf_distance and symmetric_distance function and tests --- python/tests/test_distance_metrics.py | 203 ++++++++++++++++++++++++++ python/tskit/trees.py | 46 ++++++ 2 files changed, 249 insertions(+) create mode 100644 python/tests/test_distance_metrics.py diff --git a/python/tests/test_distance_metrics.py b/python/tests/test_distance_metrics.py new file mode 100644 index 0000000000..ae34fe1b20 --- /dev/null +++ b/python/tests/test_distance_metrics.py @@ -0,0 +1,203 @@ +# MIT License +# +# Copyright (c) 2022 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Tests for tree distance metrics. 
+""" +import pytest + +import tests +import tskit + + +class TestTreeSameSamples: + # Tree1 + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + # + # Tree2 + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ ┃ 5 ┊ + # ┊ ┃ ┏━┻┓ ┊ + # 1.00┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + + @tests.cached_example + def tree(self): + return tskit.Tree.generate_balanced(4) + + @tests.cached_example + def tree_other(self): + return tskit.Tree.generate_comb(4) + + def test_symmetric_distance(self): + assert self.tree().symmetric_distance(self.tree_other()) == 2 + + def test_unweighted_robinson_foulds(self): + assert self.tree().unweighted_rf_distance(self.tree_other()) == 2 + + +class TestTreeDifferentSamples: + # Tree1 + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + # + # Tree2 + # 4.00┊ 8 ┊ + # ┊ ┏━┻━┓ ┊ + # 3.00┊ ┃ 7 ┊ + # ┊ ┃ ┏━┻━┓ ┊ + # 2.00┊ ┃ ┃ 6 ┊ + # ┊ ┃ ┃ ┏━┻┓ ┊ + # 1.00┊ ┃ ┃ ┃ 5 ┊ + # ┊ ┃ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 4 ┊ + # 0 1 + + @tests.cached_example + def tree(self): + return tskit.Tree.generate_balanced(4) + + @tests.cached_example + def tree_other(self): + return tskit.Tree.generate_comb(5) + + def test_symmetric_distance(self): + assert self.tree().symmetric_distance(self.tree_other()) == 8 + + def test_unweighted_robinson_foulds(self): + assert self.tree().unweighted_rf_distance(self.tree_other()) == 8 + + +class TestTreeMultiRoots: + # Tree1 + # 4.00┊ 15 ┊ + # ┊ ┏━━━┻━━━┓ ┊ + # 3.00┊ ┃ 14 ┊ + # ┊ ┃ ┏━┻━┓ ┊ + # 2.00┊ 12 ┃ 13 ┊ + # ┊ ┏━┻━┓ ┃ ┏┻┓ ┊ + # 1.00┊ 9 10 ┃ ┃ 11 ┊ + # ┊ ┏┻┓ ┏┻┓ ┏┻┓ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 4 5 6 7 8 ┊ + # 0 1 + # + # Tree2 + # 3.00┊ 15 ┊ + # ┊ ┏━━┻━┓ ┊ + # 2.00┊ 11 ┃ 14 ┊ + # ┊ ┏━┻━┓ ┃ ┏━┻┓ ┊ + # 1.00┊ 9 10 12 ┃ 13 ┊ + # ┊ ┏┻┓ ┏┻┓ ┏┻┓ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 4 5 6 7 8 ┊ + # 0 1 + + @tests.cached_example + def tree(self): + return tskit.Tree.generate_balanced(9) + + @tests.cached_example + def tree_other(self): + tables = tskit.Tree.generate_balanced(9, 
arity=2).tree_sequence.dump_tables() + edges = tables.edges.copy() + tables.edges.clear() + for edge in edges: + if edge.parent != 16: + tables.edges.append(edge) + return tables.tree_sequence().first() + + def test_symmetric_distance(self): + with pytest.raises(ValueError): + self.tree().symmetric_distance(self.tree_other()) + + def test_unweighted_robinson_foulds(self): + with pytest.raises(ValueError): + self.tree().unweighted_rf_distance(self.tree_other()) + + +class TestEmpty: + @tests.cached_example + def tree(self): + tables = tskit.TableCollection(1) + return tables.tree_sequence().first() + + @tests.cached_example + def tree_other(self): + tables = tskit.TableCollection(1) + return tables.tree_sequence().first() + + def test_symmetric_distance(self): + with pytest.raises(ValueError): + self.tree().symmetric_distance(self.tree_other()) + + def test_unweighted_robinson_foulds(self): + with pytest.raises(ValueError): + self.tree().unweighted_rf_distance(self.tree_other()) + + +class TestTreeInNullState: + @tests.cached_example + def tsk_tree1(self): + tree = tskit.Tree.generate_comb(5) + tree.clear() + return tree + + @tests.cached_example + def tree_other(self): + tree = tskit.Tree.generate_comb(5) + tree.clear() + return tree + + def test_symmetric_distance(self): + with pytest.raises(ValueError): + self.tsk_tree1().symmetric_distance(self.tree_other()) + + def test_unweighted_robinson_foulds(self): + with pytest.raises(ValueError): + self.tsk_tree1().unweighted_rf_distance(self.tree_other()) + + +class TestAllRootsN5: + @tests.cached_example + def tree(self): + tables = tskit.TableCollection(1) + for _ in range(5): + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + return tables.tree_sequence().first() + + def test_symmetric_distance(self): + with pytest.raises(ValueError): + self.tree().symmetric_distance(self.tree()) + + def test_unweighted_robinson_foulds(self): + with pytest.raises(ValueError): + 
+ self.tree().unweighted_rf_distance(self.tree()) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 7b58c37917..70a539f4bc 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -2840,6 +2840,52 @@ def kc_distance(self, other, lambda_=0.0): """ return self._ll_tree.get_kc_distance(other._ll_tree, lambda_) + def _get_sample_sets(self): + ret = collections.defaultdict(set) + for u in self.samples(): + ret[u].add(u) + for u in self.nodes(order="postorder"): + for v in self.children(u): + ret[u] |= ret[v] + return ret + + def unweighted_rf_distance(self, other): + """ + Returns the unweighted Robinson-Foulds distance between the specified pair of + trees. The unweighted RF distance is the number of clades (the sets of + samples descending from each node) found in only one of the two trees. + + :param Tree other: The other tree to compare to. + :return: The computed unweighted RF distance between this tree and other. + :rtype: int + """ + return self.symmetric_distance(other) + + def symmetric_distance(self, other): + """ + Returns the unweighted Robinson-Foulds distance (symmetric difference) + between the specified pair of trees. The symmetric distance is the number + of clades (sets of samples below each node) found in only one of the two trees. + + .. seealso:: + See `Robinson & Foulds (1981) + <https://doi.org/10.1016/0025-5564(81)90043-2>`_ for more details. + + :param Tree other: The other tree to compare to. + :return: The computed symmetric difference between this tree and other. + :rtype: int + """ + if self.num_roots != 1 or other.num_roots != 1: + raise ValueError("Trees must have a single root") + + b1 = self._get_sample_sets() + b2 = other._get_sample_sets() + + s1 = {frozenset(x) for x in b1.values()} + s2 = {frozenset(x) for x in b2.values()} + + return len(s1.symmetric_difference(s2)) + def path_length(self, u, v): """ Returns the path length between two nodes