Skip to content

Commit

Permalink
Allow any set of sample nodes to be passed
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Oct 31, 2022
1 parent 8fbe802 commit ac20541
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 4 deletions.
9 changes: 5 additions & 4 deletions python/tests/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(
keep_unary=False,
keep_unary_in_individuals=False,
keep_input_roots=False,
filter_nodes=True,
filter_nodes=True, # If this is False, the order in `sample` is ignored
):
self.ts = ts
self.n = len(sample)
Expand Down Expand Up @@ -148,10 +148,11 @@ def __init__(
output_id = self.record_node(sample_id, is_sample=True)
self.add_ancestry(sample_id, 0, self.sequence_length, output_id)
else:
assert list(sample) == list(ts.samples())
sample = set(sample)
for node in ts.nodes():
self.record_node(node.id, node.is_sample())
if node.is_sample():
is_sample = node.id in sample
self.record_node(node.id, is_sample=is_sample)
if is_sample:
self.add_ancestry(node.id, 0, self.sequence_length, node.id)

self.position_lookup = None
Expand Down
48 changes: 48 additions & 0 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -5851,12 +5851,35 @@ def verify_nodes_unchanged(self, ts_in, resample_size=None):
)
assert np.array_equal(n_map, np.arange(ts.num_nodes, dtype=n_map.dtype))
for n1, n2 in zip(ts.nodes(), filtered.nodes()):
# Ignore the tskit.NODE_IS_SAMPLE flag which can be changed by simplify
n1 = n1.replace(flags=n1.flags | tskit.NODE_IS_SAMPLE)
n2 = n2.replace(flags=n2.flags | tskit.NODE_IS_SAMPLE)
assert n1 == n2

# Check that edges are identical to the normal simplify(),
# with the normal "simplify" having altered IDs
simplified, node_map = ts.simplify(samples=samples, map_nodes=True)
simplified_edges = {e for e in simplified.tables.edges}
filtered_edges = {
e.replace(parent=node_map[e.parent], child=node_map[e.child])
for e in filtered.tables.edges
}
assert filtered_edges == simplified_edges

def test_empty(self):
    """Run the nodes-unchanged check on a tree sequence with empty tables."""
    # A freshly created TableCollection has no rows in any table.
    empty_tables = tskit.TableCollection(1)
    self.verify_nodes_unchanged(empty_tables.tree_sequence())

def test_all_samples(self):
    """Run the nodes-unchanged check when every node carries the sample flag."""
    tables = tskit.Tree.generate_comb(5).tree_sequence.dump_tables()
    # OR the sample bit into every node's flags and write the column back.
    tables.nodes.flags = tables.nodes.flags | tskit.NODE_IS_SAMPLE
    all_sample_ts = tables.tree_sequence()
    # Sanity check: the flag edit really did turn every node into a sample.
    assert all_sample_ts.num_samples == all_sample_ts.num_nodes
    self.verify_nodes_unchanged(all_sample_ts)

@pytest.mark.parametrize("resample_size", [None, 4])
def test_no_topology(self, resample_size):
ts = tskit.Tree.generate_comb(5).tree_sequence
Expand All @@ -5871,12 +5894,37 @@ def test_stick_tree(self, resample_size):
assert ts.first().parent(0) != tskit.NULL
self.verify_nodes_unchanged(ts, resample_size=resample_size)

# switch to an internal sample
tables = ts.dump_tables()
flags = tables.nodes.flags
flags[0] = 0
flags[1] = tskit.NODE_IS_SAMPLE
tables.nodes.flags = flags
self.verify_nodes_unchanged(tables.tree_sequence(), resample_size=resample_size)

@pytest.mark.parametrize("resample_size", [None, 4])
def test_internal_samples(self, resample_size):
    """Run the nodes-unchanged check after toggling the sample flag everywhere.

    The sample bit of every node in a 4-leaf comb tree is flipped, so nodes
    that were samples stop being samples and vice versa.
    """
    tables = tskit.Tree.generate_comb(4).tree_sequence.dump_tables()
    # XOR flips the sample bit on each node and leaves other bits untouched.
    tables.nodes.flags = tables.nodes.flags ^ tskit.NODE_IS_SAMPLE
    flipped_ts = tables.tree_sequence()
    sample_ids = flipped_ts.samples()
    # Every sample ID must now be at least as large as the sample count.
    assert np.all(sample_ids >= flipped_ts.num_samples)
    self.verify_nodes_unchanged(flipped_ts, resample_size=resample_size)

@pytest.mark.parametrize("resample_size", [None, 4])
def test_blank_flanks(self, resample_size):
    """Run the nodes-unchanged check when only a central interval is kept."""
    comb_ts = tskit.Tree.generate_comb(4).tree_sequence
    # Keep only [0.25, 0.75) and skip the simplify step of keep_intervals.
    kept_interval = [[0.25, 0.75]]
    trimmed_ts = comb_ts.keep_intervals(kept_interval, simplify=False)
    self.verify_nodes_unchanged(trimmed_ts, resample_size=resample_size)

@pytest.mark.parametrize("resample_size", [None, 4])
def test_multiroot(self, resample_size):
    """Run the nodes-unchanged check on a decapitated balanced tree."""
    balanced_ts = tskit.Tree.generate_balanced(6).tree_sequence
    # NOTE(review): decapitating at time 2.5 is presumably what yields the
    # multiple roots the test name refers to -- confirm against tskit docs.
    decapitated_ts = balanced_ts.decapitate(2.5)
    self.verify_nodes_unchanged(decapitated_ts, resample_size=resample_size)

@pytest.mark.parametrize("resample_size", [None, 10])
def test_with_metadata(self, ts_fixture_for_simplify, resample_size):
assert ts_fixture_for_simplify.num_nodes > 10
Expand Down

0 comments on commit ac20541

Please sign in to comment.