Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#652 cache evaluation array #653

Merged
merged 1 commit into from
Oct 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,21 @@ class StateVector(pybamm.Symbol):
list of domains the parameter is valid over, defaults to empty list
auxiliary_domains : dict of str, optional
dictionary of auxiliary domains
evaluation_array : list, optional
List of boolean arrays representing slices. Default is None, in which case the
evaluation_array is computed from y_slices.

*Extends:* :class:`Array`
"""

def __init__(self, *y_slices, name=None, domain=None, auxiliary_domains=None):
def __init__(
self,
*y_slices,
name=None,
domain=None,
auxiliary_domains=None,
evaluation_array=None,
):
for y_slice in y_slices:
if not isinstance(y_slice, slice):
raise TypeError("all y_slices must be slice objects")
Expand All @@ -52,7 +62,7 @@ def __init__(self, *y_slices, name=None, domain=None, auxiliary_domains=None):
self._y_slices = y_slices
self._first_point = y_slices[0].start
self._last_point = y_slices[-1].stop
self.set_evaluation_array(y_slices)
self.set_evaluation_array(y_slices, evaluation_array)
super().__init__(name=name, domain=domain, auxiliary_domains=auxiliary_domains)

@property
Expand All @@ -76,12 +86,15 @@ def evaluation_array(self):
def size(self):
return self.evaluation_array.count(True)

def set_evaluation_array(self, y_slices):
def set_evaluation_array(self, y_slices, evaluation_array):
"Set evaluation array using slices"
array = np.zeros(y_slices[-1].stop)
for y_slice in y_slices:
array[y_slice] = True
self._evaluation_array = [bool(x) for x in array]
if evaluation_array is not None and pybamm.settings.debug_mode is False:
self._evaluation_array = evaluation_array
else:
array = np.zeros(y_slices[-1].stop)
for y_slice in y_slices:
array[y_slice] = True
self._evaluation_array = [bool(x) for x in array]

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()` """
Expand Down Expand Up @@ -156,7 +169,8 @@ def new_copy(self):
*self.y_slices,
name=self.name,
domain=self.domain,
auxiliary_domains=self.auxiliary_domains
auxiliary_domains=self.auxiliary_domains,
evaluation_array=self.evaluation_array,
)

def evaluate_for_shape(self):
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_expression_tree/test_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ def test_name(self):
)
self.assertEqual(sv.name, "y[0:10,20:30,...,60:70]")

def test_pass_evaluation_array(self):
# Turn off debug mode for this test
original_debug_mode = pybamm.settings.debug_mode
pybamm.settings.debug_mode = False
# Test that evaluation array gets passed down (doesn't have to be the correct
# array for this test)
array = np.array([1, 2, 3, 4, 5])
sv = pybamm.StateVector(slice(0, 10), evaluation_array=array)
np.testing.assert_array_equal(sv.evaluation_array, array)
# Turn debug mode back to what is was before
pybamm.settings.debug_mode = original_debug_mode

def test_failure(self):
with self.assertRaisesRegex(TypeError, "all y_slices must be slice objects"):
pybamm.StateVector(slice(0, 10), 1)
Expand Down