Skip to content

Commit

Permalink
expand unit tests, fixes #11
Browse files Browse the repository at this point in the history
  • Loading branch information
gregcaporaso committed Jul 2, 2024
1 parent cbc0a24 commit 65b652c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
5 changes: 0 additions & 5 deletions q2_boots/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,9 @@
def _bootstrap_iteration(table: biom.Table, sampling_depth: int) -> biom.Table:
table = table.filter(lambda v, i, m: v.sum() >= sampling_depth,
inplace=False, axis='sample')

table = table.subsample(sampling_depth, axis='sample', by_id=False,
with_replacement=True)

if table.is_empty():
return ValueError('The output table contains no samples or features.'
'Verify your table is valid and that you provided a '
'shallow enough samplign depth')

return table

Expand Down
37 changes: 34 additions & 3 deletions q2_boots/tests/test_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,43 @@

class TestBootstrapIteration(TestCase):

def test_bootstrap_iteration(self):
def test_bootstrap_iteration_filters_samples(self):
t = Table(np.array([[0, 1, 3], [1, 1, 2]]),
['O1', 'O2'],
['S1', 'S2', 'S3'])
a = _bootstrap_iteration(t, 2)
self.assertEqual(a.shape, (2, 2))

observed = _bootstrap_iteration(t, 6)
self.assertTrue(observed.is_empty())

observed = _bootstrap_iteration(t, 5)
self.assertEqual(list(observed.ids(axis='sample')), ['S3'])

observed = _bootstrap_iteration(t, 2)
self.assertEqual(list(observed.ids(axis='sample')), ['S2', 'S3'])

observed = _bootstrap_iteration(t, 1)
self.assertEqual(list(observed.ids(axis='sample')), ['S1', 'S2', 'S3'])

def test_bootstrap_iteration_obtains_expected_counts(self):
t = Table(np.array([[0, 10, 30], [1, 10, 20]]),
['O1', 'O2'],
['S1', 'S2', 'S3'])

observed = _bootstrap_iteration(t, 1)
self.assertEqual(list(observed.sum(axis="sample")), [1., 1., 1.])

observed = _bootstrap_iteration(t, 10)
self.assertEqual(list(observed.sum(axis="sample")), [10., 10.])

observed = _bootstrap_iteration(t, 19)
self.assertEqual(list(observed.sum(axis="sample")), [19., 19.])

observed = _bootstrap_iteration(t, 25)
self.assertEqual(list(observed.sum(axis="sample")), [25.])

observed = _bootstrap_iteration(t, 49)
self.assertEqual(list(observed.sum(axis="sample")), [49.])



if __name__ == "__main__":
Expand Down

0 comments on commit 65b652c

Please sign in to comment.