diff --git a/CHANGELOG.md b/CHANGELOG.md index a47ab6a02a..d5b88278ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- Add `Interpolant` class to interpolate experimental data (e.g. OCP curves) (#661) - Allow parameters to be set by material or by specifying a particular paper (#647) - Set relative and absolute tolerances independently in solvers (#645) diff --git a/docs/source/expression_tree/index.rst b/docs/source/expression_tree/index.rst index 8c2099ac24..2e95d07b16 100644 --- a/docs/source/expression_tree/index.rst +++ b/docs/source/expression_tree/index.rst @@ -17,4 +17,5 @@ Expression Tree broadcasts simplify functions + interpolant evaluate diff --git a/docs/source/expression_tree/interpolant.rst b/docs/source/expression_tree/interpolant.rst new file mode 100644 index 0000000000..69ad86e023 --- /dev/null +++ b/docs/source/expression_tree/interpolant.rst @@ -0,0 +1,5 @@ +Interpolant +=========== + +.. autoclass:: pybamm.Interpolant + :members: diff --git a/input/parameters/lithium-ion/cathodes/lico2_Marquis2019/lico2_data_example.csv b/input/parameters/lithium-ion/cathodes/lico2_Marquis2019/lico2_data_example.csv new file mode 100644 index 0000000000..f2f1809c79 --- /dev/null +++ b/input/parameters/lithium-ion/cathodes/lico2_Marquis2019/lico2_data_example.csv @@ -0,0 +1,50 @@ +0.000000000000000000e+00 4.714135898019971016e+00 +2.040816326530612082e-02 4.708899441575220557e+00 +4.081632653061224164e-02 4.702448345762175741e+00 +6.122448979591836593e-02 4.694558534379876136e+00 +8.163265306122448328e-02 4.684994372928071193e+00 +1.020408163265306006e-01 4.673523893805322516e+00 +1.224489795918367319e-01 4.659941254449398329e+00 +1.428571428571428492e-01 4.644096031712390271e+00 +1.632653061224489666e-01 4.625926611260677390e+00 +1.836734693877550839e-01 4.605491824833229053e+00 +2.040816326530612013e-01 4.582992038370575116e+00 +2.244897959183673186e-01 4.558769704421606228e+00 +2.448979591836734637e-01 4.533281647154224103e+00 +2.653061224489795533e-01 4.507041620859735254e+00 +2.857142857142856984e-01 4.480540404981123714e+00 +3.061224489795917880e-01 4.454158468368703439e+00 +3.265306122448979331e-01 4.428089899175588151e+00 +3.469387755102040782e-01 4.402295604083254155e+00 +3.673469387755101678e-01 4.376502631465185367e+00 +3.877551020408163129e-01 4.350272100879827519e+00 +4.081632653061224025e-01 4.323179536958428493e+00 +4.285714285714285476e-01 4.295195829713853719e+00 +4.489795918367346372e-01 4.267407675466301065e+00 +4.693877551020407823e-01 4.243081968022011985e+00 +4.897959183673469274e-01 4.220583168834260768e+00 +5.102040816326530726e-01 4.177032236370062712e+00 +5.306122448979591066e-01 4.134943568540559333e+00 +5.510204081632652517e-01 4.075402582839823928e+00 +5.714285714285713969e-01 4.055407164381796825e+00 +5.918367346938775420e-01 4.036052896449991323e+00 +6.122448979591835760e-01 4.012970397550268409e+00 +6.326530612244897211e-01 3.990385577539371287e+00 +6.530612244897958663e-01 3.970744780585252709e+00 +6.734693877551020114e-01 3.954753574690877738e+00 +6.938775510204081565e-01 3.942237451863396025e+00 +7.142857142857141906e-01 3.932683425747200534e+00 +7.346938775510203357e-01 3.925509771581312979e+00 +7.551020408163264808e-01 3.920182838859009422e+00 +7.755102040816326259e-01 3.916256861206461881e+00 +7.959183673469386600e-01 3.913378070528176877e+00 +8.163265306122448051e-01 3.911274218446639583e+00 +8.367346938775509502e-01 3.909739285381772067e+00 +8.571428571428570953e-01 3.908613829807601192e+00 +8.775510204081632404e-01 3.907726324580658162e+00 +8.979591836734692745e-01 3.906474088522892796e+00 +9.183673469387754196e-01 3.900204875423951556e+00 +9.387755102040815647e-01 3.848912814816038974e+00 +9.591836734693877098e-01 3.445226042113884724e+00 +9.795918367346938549e-01 1.687177743081021308e+00 +1.000000000000000000e+00 6.378908986260003328e-03 diff --git a/input/parameters/lithium-ion/cathodes/lico2_Marquis2019/lico2_ocp_Dualfoil1998.py b/input/parameters/lithium-ion/cathodes/lico2_Marquis2019/lico2_ocp_Dualfoil1998.py index e63d8c7074..1d6f3a306e 100644 --- a/input/parameters/lithium-ion/cathodes/lico2_Marquis2019/lico2_ocp_Dualfoil1998.py +++ b/input/parameters/lithium-ion/cathodes/lico2_Marquis2019/lico2_ocp_Dualfoil1998.py @@ -3,23 +3,23 @@ def lico2_ocp_Dualfoil1998(sto): """ - Lithium Cobalt Oxide (LiCO2) Open Circuit Potential (OCP) as a a function of the - stochiometry. The fit is taken from Dualfoil [1]. Dualfoil states that the data - was measured by Oscar Garcia 2001 using Quallion electrodes for 0.5 < sto < 0.99 - and by Marc Doyle for sto<0.4 (for unstated electrodes). We could not find any - other records of the Garcia measurements. Doyles fits can be found in his - thesis [2] but we could not find any other record of his measurments. - - References - ---------- - .. [1] http://www.cchem.berkeley.edu/jsngrp/fortran.html - .. [2] CM Doyle. Design and simulation of lithium rechargeable batteries, - 1995. - - Parameters - ---------- - sto: double - Stochiometry of material (li-fraction) + Lithium Cobalt Oxide (LiCO2) Open Circuit Potential (OCP) as a a function of the + stochiometry. The fit is taken from Dualfoil [1]. Dualfoil states that the data + was measured by Oscar Garcia 2001 using Quallion electrodes for 0.5 < sto < 0.99 + and by Marc Doyle for sto<0.4 (for unstated electrodes). We could not find any + other records of the Garcia measurements. Doyles fits can be found in his + thesis [2] but we could not find any other record of his measurments. + + References + ---------- + .. [1] http://www.cchem.berkeley.edu/jsngrp/fortran.html + .. [2] CM Doyle. Design and simulation of lithium rechargeable batteries, + 1995. + + Parameters + ---------- + sto: double + Stochiometry of material (li-fraction) """ diff --git a/pybamm/__init__.py b/pybamm/__init__.py index 4f2f3680dc..7723eb8742 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -126,6 +126,7 @@ def version(formatted=False): r_average, ) from .expression_tree.functions import * +from .expression_tree.interpolant import Interpolant from .expression_tree.parameter import Parameter, FunctionParameter from .expression_tree.broadcasts import Broadcast, PrimaryBroadcast, FullBroadcast from .expression_tree.scalar import Scalar diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 2bacdb792c..20f6ae46e4 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -18,20 +18,26 @@ class Function(pybamm.Symbol): func(child0.evaluate(t, y), child1.evaluate(t, y), etc). children : :class:`pybamm.Symbol` The children nodes to apply the function to - + derivative : str, optional + Which derivative to use when differentiating ("autograd" or "derivative"). + Default is "autograd". **Extends:** :class:`pybamm.Symbol` """ - def __init__(self, function, *children): + def __init__(self, function, *children, name=None, derivative="autograd"): - try: - name = "function ({})".format(function.__name__) - except AttributeError: - name = "function ({})".format(function.__class__) + if name is not None: + self.name = name + else: + try: + name = "function ({})".format(function.__name__) + except AttributeError: + name = "function ({})".format(function.__class__) children_list = list(children) domain = self.get_children_domains(children_list) self.function = function + self.derivative = derivative # hack to work out whether function takes any params # (signature doesn't work for numpy) @@ -73,7 +79,7 @@ def diff(self, variable): # if variable appears in the function,use autograd to differentiate # function, and apply chain rule if variable.id in [symbol.id for symbol in child.pre_order()]: - partial_derivatives[i] = child.diff(variable) * self._diff(children) + partial_derivatives[i] = self._diff(children) * child.diff(variable) # remove None entries partial_derivatives = list(filter(None, partial_derivatives)) @@ -86,7 +92,13 @@ def diff(self, variable): def _diff(self, children): """ See :meth:`pybamm.Symbol._diff()`. """ - return Function(autograd.elementwise_grad(self.function), *children) + if self.derivative == "autograd": + return Function(autograd.elementwise_grad(self.function), *children) + elif self.derivative == "derivative": + # keep using "derivative" as derivative + return pybamm.Function( + self.function.derivative(), *children, derivative="derivative" + ) def _jac(self, variable): """ See :meth:`pybamm.Symbol._jac()`. """ @@ -158,7 +170,9 @@ def _function_new_copy(self, children): : :pybamm.Function A new copy of the function """ - return pybamm.Function(self.function, *children) + return pybamm.Function( + self.function, *children, name=self.name, derivative=self.derivative + ) def _function_simplify(self, simplified_children): """ @@ -181,7 +195,12 @@ def _function_simplify(self, simplified_children): # If self.function() is a constant current then simplify to scalar return pybamm.Scalar(self.function.parameters_eval["Current [A]"]) else: - return pybamm.Function(self.function, *simplified_children) + return pybamm.Function( + self.function, + *simplified_children, + name=self.name, + derivative=self.derivative + ) class SpecificFunction(Function): @@ -233,7 +252,7 @@ def __init__(self, child): super().__init__(np.cosh, child) def _diff(self, children): - """ See :meth:`pybamm.Symbol._diff()`. """ + """ See :meth:`pybamm.Function._diff()`. """ return Sinh(children[0]) @@ -249,7 +268,7 @@ def __init__(self, child): super().__init__(np.exp, child) def _diff(self, children): - """ See :meth:`pybamm.Symbol._diff()`. """ + """ See :meth:`pybamm.Function._diff()`. """ return Exponential(children[0]) @@ -265,7 +284,7 @@ def __init__(self, child): super().__init__(np.log, child) def _diff(self, children): - """ See :meth:`pybamm.Symbol._diff()`. """ + """ See :meth:`pybamm.Function._diff()`. """ return 1 / children[0] @@ -291,7 +310,7 @@ def __init__(self, child): super().__init__(np.sin, child) def _diff(self, children): - """ See :meth:`pybamm.Symbol._diff()`. """ + """ See :meth:`pybamm.Function._diff()`. """ return Cos(children[0]) @@ -307,7 +326,7 @@ def __init__(self, child): super().__init__(np.sinh, child) def _diff(self, children): - """ See :meth:`pybamm.Symbol._diff()`. """ + """ See :meth:`pybamm.Function._diff()`. """ return Cosh(children[0]) diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py new file mode 100644 index 0000000000..5cb8a9c6ca --- /dev/null +++ b/pybamm/expression_tree/interpolant.py @@ -0,0 +1,64 @@ +# +# Interpolating class +# +import pybamm +from scipy import interpolate + + +class Interpolant(pybamm.Function): + """ + Interpolate data in 1D. + + Parameters + ---------- + data : :class:`numpy.ndarray` + Numpy array of data to use for interpolation. Must have exactly two columns (x + and y data) + child : :class:`pybamm.Symbol` + Node to use when evaluating the interpolant + name : str, optional + Name of the interpolant. Default is None, in which case the name "interpolating + function" is given. + interpolator : str, optional + Which interpolator to use ("pchip" or "cubic spline"). Note that whichever + interpolator is used must be differentiable (for ``Interpolator._diff``). + Default is "cubic spline". Note that "pchip" may give slow results. + extrapolate : bool, optional + Whether to extrapolate for points that are outside of the parametrisation + range, or return NaN (following default behaviour from scipy). Default is True. + + **Extends**: :class:`pybamm.Function` + """ + + def __init__( + self, data, child, name=None, interpolator="cubic spline", extrapolate=True + ): + if data.ndim != 2 or data.shape[1] != 2: + raise ValueError( + """ + data should have exactly two columns (x and y) but has shape {} + """.format( + data.shape + ) + ) + elif interpolator == "pchip": + interpolating_function = interpolate.PchipInterpolator( + data[:, 0], data[:, 1], extrapolate=extrapolate + ) + elif interpolator == "cubic spline": + interpolating_function = interpolate.CubicSpline( + data[:, 0], data[:, 1], extrapolate=extrapolate + ) + else: + raise ValueError("interpolator '{}' not recognised".format(interpolator)) + # Set name + if name is not None: + name = "interpolating function ({})".format(name) + else: + name = "interpolating function" + super().__init__( + interpolating_function, child, name=name, derivative="derivative" + ) + # Store information as attributes + self.interpolator = interpolator + self.extrapolate = extrapolate diff --git a/pybamm/expression_tree/simplify.py b/pybamm/expression_tree/simplify.py index 97ac0ab595..7df95240c4 100644 --- a/pybamm/expression_tree/simplify.py +++ b/pybamm/expression_tree/simplify.py @@ -16,7 +16,9 @@ def simplify_if_constant(symbol): if symbol.is_constant(): result = symbol.evaluate_ignoring_errors() if result is not None: - if isinstance(result, numbers.Number): + if isinstance(result, numbers.Number) or ( + isinstance(result, np.ndarray) and result.ndim == 0 + ): return pybamm.Scalar(result) elif isinstance(result, np.ndarray) or issparse(result): if result.ndim == 1 or result.shape[1] == 1: diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index cb2fec3193..9473f481c4 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -4,6 +4,7 @@ import pybamm import pandas as pd import os +import numpy as np class ParameterValues(dict): @@ -101,26 +102,7 @@ def update_from_chemistry(self, chemistry): os.path.join(component_path, "parameters.csv") ) # Update parameters, making sure to check any conflicts - self.update(component_params, check_conflict=True) - # Load functions if they are specified - for name, param in component_params.items(): - # Functions are flagged with the string "[function]" - if isinstance(param, str): - if param.startswith("[function]"): - self[name] = pybamm.load_function( - os.path.join(component_path, param[10:] + ".py") - ) - # Inbuilt functions are flagged with the string "[inbuilt]" - elif param.startswith("[inbuilt class]"): - # Extra set of brackets at the end makes an instance of the - # class - self[name] = getattr(pybamm, param[15:])() - # Data is flagged with the string "[data]" - # elif param.startswith("[data]"): - # TODO: implement interpolating function for data - # Anything else should be a converted to a float - else: - self[name] = float(param) + self.update(component_params, check_conflict=True, path=component_path) def read_parameters_csv(self, filename): """Reads parameters from csv file into dict. @@ -141,23 +123,45 @@ def read_parameters_csv(self, filename): df.dropna(how="all", inplace=True) return {k: v for (k, v) in zip(df["Name [units]"], df["Value"])} - def update(self, values, check_conflict=False): + def update(self, values, check_conflict=False, path=""): # check parameter values values = self.check_and_update_parameter_values(values) # update - for k, v in values.items(): + for name, value in values.items(): # check for conflicts if ( check_conflict is True - and k in self.keys() - and not (self[k] == float(v) or self[k] == v) + and name in self.keys() + and not (self[name] == float(value) or self[name] == value) ): raise ValueError( - "parameter '{}' already defined with value '{}'".format(k, self[k]) + "parameter '{}' already defined with value '{}'".format( + name, self[name] + ) ) - # if no conflicts, update + # if no conflicts, update, loading functions and data if they are specified else: - self[k] = v + # Functions are flagged with the string "[function]" + if isinstance(value, str): + if value.startswith("[function]"): + self[name] = pybamm.load_function( + os.path.join(path, value[10:] + ".py") + ) + # Inbuilt functions are flagged with the string "[inbuilt]" + elif value.startswith("[inbuilt class]"): + # Extra set of brackets at the end makes an instance of the + # class + self[name] = getattr(pybamm, value[15:])() + # Data is flagged with the string "[data]" + elif value.startswith("[data]"): + data = np.loadtxt(os.path.join(path, value[6:] + ".csv")) + # Save name and data + self[name] = (value[6:], data) + # Anything else should be a converted to a float + else: + self[name] = float(value) + else: + self[name] = value # reset processed symbols self._processed_symbols = {} @@ -392,8 +396,16 @@ def _process_symbol(self, symbol): if isinstance(function_name, pybamm.GetCurrentData): function_name.interpolate() - # Create Function object and differentiate if necessary - function = pybamm.Function(function_name, *new_children) + # Create Function or Interpolant objec + if isinstance(function_name, tuple): + # If function_name is a tuple then it should be (name, data) and we need + # to create an Interpolant + name, data = function_name + function = pybamm.Interpolant(data, *new_children, name=name) + else: + # otherwise create standard function + function = pybamm.Function(function_name, *new_children) + # Differentiate if necessary if symbol.diff_variable is None: return function else: diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py new file mode 100644 index 0000000000..84de9eef76 --- /dev/null +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -0,0 +1,83 @@ +# +# Tests for the Function classes +# +import pybamm + +import unittest +import numpy as np + + +class TestInterpolant(unittest.TestCase): + def test_errors(self): + with self.assertRaisesRegex(ValueError, "data should have exactly two columns"): + pybamm.Interpolant(np.ones(10), None) + with self.assertRaisesRegex(ValueError, "interpolator 'bla' not recognised"): + pybamm.Interpolant(np.ones((10, 2)), None, interpolator="bla") + + def test_interpolation(self): + x = np.linspace(0, 1)[:, np.newaxis] + y = pybamm.StateVector(slice(0, 2)) + # linear + linear = np.hstack([x, 2 * x]) + for interpolator in ["pchip", "cubic spline"]: + interp = pybamm.Interpolant(linear, y, interpolator=interpolator) + np.testing.assert_array_almost_equal( + interp.evaluate(y=np.array([0.397, 1.5]))[:, 0], np.array([0.794, 3]) + ) + # square + square = np.hstack([x, x ** 2]) + y = pybamm.StateVector(slice(0, 1)) + for interpolator in ["pchip", "cubic spline"]: + interp = pybamm.Interpolant(square, y, interpolator=interpolator) + np.testing.assert_array_almost_equal( + interp.evaluate(y=np.array([0.397]))[:, 0], np.array([0.397 ** 2]) + ) + + # with extrapolation set to False + for interpolator in ["pchip", "cubic spline"]: + interp = pybamm.Interpolant( + square, y, interpolator=interpolator, extrapolate=False + ) + np.testing.assert_array_equal( + interp.evaluate(y=np.array([2]))[:, 0], np.array([np.nan]) + ) + + def test_name(self): + a = pybamm.Symbol("a") + x = np.linspace(0, 1)[:, np.newaxis] + interp = pybamm.Interpolant(np.hstack([x, x]), a, "name") + self.assertEqual(interp.name, "interpolating function (name)") + + def test_diff(self): + x = np.linspace(0, 1)[:, np.newaxis] + y = pybamm.StateVector(slice(0, 2)) + # linear (derivative should be 2) + linear = np.hstack([x, 2 * x]) + for interpolator in ["pchip", "cubic spline"]: + interp_diff = pybamm.Interpolant(linear, y, interpolator=interpolator).diff( + y + ) + np.testing.assert_array_almost_equal( + interp_diff.evaluate(y=np.array([0.397, 1.5]))[:, 0], np.array([2, 2]) + ) + # square (derivative should be 2*x) + square = np.hstack([x, x ** 2]) + for interpolator in ["pchip", "cubic spline"]: + interp_diff = pybamm.Interpolant(square, y, interpolator=interpolator).diff( + y + ) + np.testing.assert_array_almost_equal( + interp_diff.evaluate(y=np.array([0.397, 0.806]))[:, 0], + np.array([0.794, 1.612]), + decimal=3, + ) + + +if __name__ == "__main__": + print("Add -v for more debug output") + import sys + + if "-v" in sys.argv: + debug = True + pybamm.settings.debug_mode = True + unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 08311bcd6f..db92388f93 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -2,10 +2,10 @@ # Tests for the Base Parameter Values class # import pybamm - -import unittest +import os import numpy as np +import unittest import tests.shared as shared @@ -25,7 +25,7 @@ def test_init(self): values="input/parameters/lithium-ion/cathodes/lico2_Marquis2019/" + "parameters.csv" ) - self.assertEqual(param["Reference temperature [K]"], "298.15") + self.assertEqual(param["Reference temperature [K]"], 298.15) # values vs chemistry with self.assertRaisesRegex( @@ -294,6 +294,61 @@ def D(a, b): self.assertIsInstance(processed_func, pybamm.Function) self.assertEqual(processed_func.evaluate(), 3) + def test_process_interpolant(self): + x = np.linspace(0, 10)[:, np.newaxis] + data = np.hstack([x, 2 * x]) + parameter_values = pybamm.ParameterValues( + {"a": 3.01, "Diffusivity": ("times two", data)} + ) + + a = pybamm.Parameter("a") + func = pybamm.FunctionParameter("Diffusivity", a) + + processed_func = parameter_values.process_symbol(func) + self.assertIsInstance(processed_func, pybamm.Interpolant) + self.assertEqual(processed_func.evaluate(), 6.02) + + # process differentiated function parameter + diff_func = func.diff(a) + processed_diff_func = parameter_values.process_symbol(diff_func) + self.assertEqual(processed_diff_func.evaluate(), 2) + + def test_interpolant_against_function(self): + parameter_values = pybamm.ParameterValues({"a": 0.6}) + parameter_values.update( + { + "function": "[function]lico2_ocp_Dualfoil1998", + "interpolation": "[data]lico2_data_example", + }, + path=os.path.join( + pybamm.root_dir(), + "input", + "parameters", + "lithium-ion", + "cathodes", + "lico2_Marquis2019", + ), + ) + + a = pybamm.Parameter("a") + func = pybamm.FunctionParameter("function", a) + interp = pybamm.FunctionParameter("interpolation", a) + + processed_func = parameter_values.process_symbol(func) + processed_interp = parameter_values.process_symbol(interp) + np.testing.assert_array_almost_equal( + processed_func.evaluate(), processed_interp.evaluate(), decimal=4 + ) + + # process differentiated function parameter + diff_func = func.diff(a) + diff_interp = interp.diff(a) + processed_diff_func = parameter_values.process_symbol(diff_func) + processed_diff_interp = parameter_values.process_symbol(diff_interp) + np.testing.assert_array_almost_equal( + processed_diff_func.evaluate(), processed_diff_interp.evaluate(), decimal=2 + ) + def test_process_complex_expression(self): var1 = pybamm.Variable("var1") var2 = pybamm.Variable("var2")