Skip to content
This repository has been archived by the owner on Jul 10, 2021. It is now read-only.

Commit

Permalink
Tests for pull request #158.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjc committed Dec 31, 2015
1 parent 791c43c commit 7831214
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions sknn/tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,37 @@ def terminate(**_):
assert_equals(self.counter, 1)


class TestBatchSize(unittest.TestCase):

def setUp(self):
self.batch_count = 0
self.nn = MLP(
layers=[L("Rectifier")],
learning_rate=0.001, n_iter=1,
callback={'on_batch_start': self.on_batch_start})

def on_batch_start(self, **args):
self.batch_count += 1

def test_BatchSizeLargerThanInput(self):
self.nn.batch_size = 32
a_in, a_out = numpy.zeros((8,16)), numpy.ones((8,4))
self.nn._fit(a_in, a_out)
assert_equals(1, self.batch_count)

def test_BatchSizeSmallerThanInput(self):
self.nn.batch_size = 4
a_in, a_out = numpy.ones((8,16)), numpy.zeros((8,4))
self.nn._fit(a_in, a_out)
assert_equals(2, self.batch_count)

def test_BatchSizeNonMultiple(self):
self.nn.batch_size = 4
a_in, a_out = numpy.zeros((9,16)), numpy.ones((9,4))
self.nn._fit(a_in, a_out)
assert_equals(3, self.batch_count)


class TestCustomLogging(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit 7831214

Please sign in to comment.