Skip to content

Commit

Permalink
remove prints #4087
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 5, 2024
1 parent 90920c6 commit 5d2f7c6
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
*.csv
*.hidden
*.pkl
*.env

# don't ignore important .txt and .csv files
!requirements*
Expand Down
35 changes: 9 additions & 26 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,15 @@ def _check_atol_type(self, atol, size):

return atol

def set_up(self, model, input_list=None, t_eval=None, ics_only=False):
base_set_up_return = super().set_up(model, input_list, t_eval, ics_only)
def set_up(self, model, inputs_list=None, t_eval=None, ics_only=False):
base_set_up_return = super().set_up(model, inputs_list, t_eval, ics_only)

inputs_list = input_list or [{}]
inputs_list = inputs_list or [{}]
nparams = sum(
len(np.array(v).reshape(-1, 1)) for _, v in inputs_list[0].items()
)
ninputs = len(inputs_list)
nstates = model.y0.shape[0]
print("nstates", nstates)
print("nparams", nparams)
print("ninputs", ninputs)

# stack inputs
if inputs_list and len(inputs_list) > 0 and len(inputs_list[0]) > 0:
Expand All @@ -173,22 +170,20 @@ def set_up(self, model, input_list=None, t_eval=None, ics_only=False):
for inputs in inputs_list
for x in inputs.values()
]
inputs_sizes = [len(array) for array in arrays_to_stack]
print(arrays_to_stack)
inputs_sizes = [
len(array) for array in arrays_to_stack[: len(inputs_list[0])]
]
inputs = np.vstack(arrays_to_stack)
else:
inputs_sizes = []
inputs = np.array([[]])

def inputs_to_dict(inputs):
index = 0
new_inputs_list = []
for inputs in inputs_list:
inputs_dict = {}
for n, key in zip(inputs_sizes, inputs.keys()):
inputs_dict[key] = inputs[index : (index + n)]
for i in range(ninputs):
for n, key in zip(inputs_sizes, inputs_list[0].keys()):
inputs_list[i][key] = inputs[index : (index + n)]
index += n
new_inputs_list.append(inputs_dict)
return inputs_list

y0 = model.y0
Expand Down Expand Up @@ -229,12 +224,9 @@ def inputs_to_dict(inputs):
if model.convert_to_format == "casadi":
# TODO: do we need densify here?
rhs_algebraic = model.rhs_algebraic_eval
print("rhs_algebraic_eval", rhs_algebraic(0, y0, inputs))
else:

def resfn(t, y, inputs, ydot):
print("resfn", y, inputs)
print(model.rhs_algebraic_eval(t, y, inputs_to_dict(inputs)))
return (
model.rhs_algebraic_eval(t, y, inputs_to_dict(inputs)).flatten()
- mass_matrix @ ydot
Expand All @@ -246,7 +238,6 @@ def resfn(t, y, inputs, ydot):
# need to provide jacobian_rhs_alg - cj * mass_matrix
if model.convert_to_format == "casadi":
t_casadi = casadi.MX.sym("t")
print("nstates", nstates)
y_casadi = casadi.MX.sym("y", nstates)
cj_casadi = casadi.MX.sym("cj")
p_casadi_stacked = casadi.MX.sym("p_stacked", nparams * ninputs)
Expand All @@ -259,7 +250,6 @@ def resfn(t, y, inputs, ydot):
- cj_casadi * mass_matrix
],
)
print("jac_times_cjmass", jac_times_cjmass(0, y0, inputs, 1))

jac_times_cjmass_sparsity = jac_times_cjmass.sparsity_out(0)
jac_bw_lower = jac_times_cjmass_sparsity.bw_lower()
Expand Down Expand Up @@ -315,19 +305,15 @@ def resfn(t, y, inputs, ydot):
if sparse.issparse(jac_y0_t0):

def jacfn(t, y, inputs, cj):
print("calling jacfn", y, inputs, inputs_to_dict(inputs))
print(model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs)))
j = (
model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs))
- cj * mass_matrix
)
print("jacfn", j)
return j

else:

def jacfn(t, y, inputs, cj):
print("calling jacfn", y, inputs)
jac_eval = (
model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs))
- cj * mass_matrix
Expand Down Expand Up @@ -375,7 +361,6 @@ def get_jac_col_ptrs(self):
)
],
)
print("rootfn", rootfn(0, y0, inputs))
else:

def rootfn(t, y, inputs):
Expand All @@ -393,7 +378,6 @@ def rootfn(t, y, inputs):
rhs_ids = np.ones(model.rhs_eval(0, y0, inputs_list).shape[0])
alg_ids = np.zeros(len(y0) - len(rhs_ids))
ids = np.concatenate((rhs_ids, alg_ids))
print("ids", ids)

number_of_sensitivity_parameters = 0
if model.jacp_rhs_algebraic_eval is not None:
Expand Down Expand Up @@ -557,7 +541,6 @@ def _integrate(self, model, t_eval, inputs_list=None):
inputs = np.vstack(arrays_to_stack)
else:
inputs = np.array([[]])
print("inputs", inputs)

# do this here cause y0 is set after set_up (calc consistent conditions)
y0 = model.y0
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_ida_roberts_klu(self):
# this test implements a python version of the ida Roberts
# example provided in sundials
# see sundials ida examples pdf
for form in ["casadi", "jax"]:
for form in ["python", "casadi", "jax"]:
if form == "jax" and not pybamm.have_jax():
continue
if form == "casadi":
Expand Down Expand Up @@ -141,6 +141,7 @@ def test_model_events(self):
def test_input_params(self):
# test a mix of scalar and vector input params
for form in ["python", "casadi", "jax"]:
print("form", form)
if form == "jax" and not pybamm.have_jax():
continue
if form == "casadi":
Expand Down

0 comments on commit 5d2f7c6

Please sign in to comment.