From 51981c4843eadf4dd304accccb91a4116c77817e Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Fri, 10 May 2024 20:08:42 +0100 Subject: [PATCH] bug: use casadi MX.interpn_linear function instead of plugin #3783 (#4077) * bug: use casadi MX.interpn_linear function instead of plugin #3783 * bug: fix for 2d and 3d linear interpolant #3783 * cover cubic interpolation in 2d #3783 * #3783 add to changelog --------- Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> --- CHANGELOG.md | 1 + pybamm/expression_tree/interpolant.py | 3 +- .../operations/convert_to_casadi.py | 32 ++++++++++++++----- .../test_operations/test_convert_to_casadi.py | 2 +- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a125d704f0..96d963e2b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ - Updated `plot_voltage_components.py` to support both `Simulation` and `Solution` objects. Added new methods in both `Simulation` and `Solution` classes for allow the syntax `simulation.plot_voltage_components` and `solution.plot_voltage_components`. Updated `test_plot_voltage_components.py` to reflect these changes ([#3723](https://github.com/pybamm-team/PyBaMM/pull/3723)). - The SEI thickness decreased at some intervals when the 'electron-migration limited' model was used. It has been corrected ([#3622](https://github.com/pybamm-team/PyBaMM/pull/3622)) - Allow input parameters in ESOH model ([#3921](https://github.com/pybamm-team/PyBaMM/pull/3921)) +- Use casadi MX.interpn_linear function instead of plugin to fix casadi_interpolant_linear.dll not found on Windows ([#4077](https://github.com/pybamm-team/PyBaMM/pull/4077)) ## Optimizations diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 10881d3084..dd0980fb46 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -126,9 +126,10 @@ def __init__( fill_value_1 = "extrapolate" interpolating_function = interpolate.interp1d( x1, - y.T, + y, bounds_error=False, fill_value=fill_value_1, + axis=0, ) elif interpolator == "cubic": interpolating_function = interpolate.CubicSpline( diff --git a/pybamm/expression_tree/operations/convert_to_casadi.py b/pybamm/expression_tree/operations/convert_to_casadi.py index 196da9dec9..274fd95154 100644 --- a/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/pybamm/expression_tree/operations/convert_to_casadi.py @@ -157,15 +157,31 @@ def _convert(self, symbol, t, y, y_dot, inputs): ) if len(converted_children) == 1: - return casadi.interpolant( - "LUT", solver, symbol.x, symbol.y.flatten() - )(*converted_children) + if solver == "linear": + test = casadi.MX.interpn_linear( + symbol.x, symbol.y.flatten(), converted_children + ) + if test.shape[0] == 1 and test.shape[1] > 1: + # for some reason, pybamm.Interpolant always returns a column vector, so match that + test = test.T + return test + else: + return casadi.interpolant( + "LUT", solver, symbol.x, symbol.y.flatten() + )(*converted_children) elif len(converted_children) in [2, 3]: - LUT = casadi.interpolant( - "LUT", solver, symbol.x, symbol.y.ravel(order="F") - ) - res = LUT(casadi.hcat(converted_children).T).T - return res + if solver == "linear": + return casadi.MX.interpn_linear( + symbol.x, + symbol.y.ravel(order="F"), + converted_children, + ) + else: + LUT = casadi.interpolant( + "LUT", solver, symbol.x, symbol.y.ravel(order="F") + ) + res = LUT(casadi.hcat(converted_children).T).T + return res else: # pragma: no cover raise ValueError( f"Invalid converted_children count: {len(converted_children)}" diff --git a/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py b/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py index db50ac8c92..2b9aa08479 100644 --- a/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py +++ b/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py @@ -219,7 +219,7 @@ def test_interpolation_2d(self): # linear y_test = np.array([0.4, 0.6]) Y = (2 * x).sum(axis=1).reshape(*[len(el) for el in x_]) - for interpolator in ["linear"]: + for interpolator in ["linear", "cubic"]: interp = pybamm.Interpolant(x_, Y, y, interpolator=interpolator) interp_casadi = interp.to_casadi(y=casadi_y) f = casadi.Function("f", [casadi_y], [interp_casadi])