Skip to content
This repository has been archived by the owner on Jul 10, 2021. It is now read-only.

Commit

Permalink
Support for numpy memory mapped arrays and tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjc committed Feb 26, 2016
1 parent 1d9321b commit 39b40d8
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sknn/backend/lasagne/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def cast(array, indices):
return None

array = array[indices]
if type(array) != numpy.ndarray:
if not isinstance(array, numpy.ndarray):
assert hasattr(array, 'todense'), "Unknown data format and cannot convert to numpy.ndarray."
array = array.todense()
if array.dtype != theano.config.floatX:
array = array.astype(theano.config.floatX)
Expand Down
30 changes: 30 additions & 0 deletions sknn/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import unittest
from nose.tools import (assert_is_not_none, assert_raises, assert_equal, assert_true)

import os
import random
import shutil
import tempfile

import theano
import numpy
Expand Down Expand Up @@ -74,6 +77,33 @@ def test_Predict32(self):
assert_equal(yp.dtype, numpy.float32)


class TestMemoryMap(unittest.TestCase):

def setUp(self):
self.nn = MLP(layers=[L("Linear", units=3)], n_iter=1)
self.directory = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.directory)

def make(self, name, shape, dtype):
filename = os.path.join(self.directory, name)
return numpy.memmap(filename, dtype=dtype, mode='w+', shape=shape)

def test_FitAllTypes(self):
for t in ['float32', 'float64']:
theano.config.floatX = t
X = self.make('X', (12, 3), dtype=t)
y = self.make('y', (12, 3), dtype=t)
self.nn._fit(X, y)

def test_PredictAllTypes(self):
for t in ['float32', 'float64']:
theano.config.floatX = t
X = self.make('X', (12, 3), dtype=t)
yp = self.nn._predict(X)


class TestConvolution(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit 39b40d8

Please sign in to comment.