From 5abff4fb49b8d00b8903e0b745f805f8169ff6c0 Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Fri, 13 Sep 2024 15:14:17 +0200 Subject: [PATCH] Reapply "Newton error propagation based convergence check." (#1419) * Set `dt` if needed. * Use `(X - X_old) / dt = dX`. Instead of `(X - X_old) = dt * dX`. The implemented scaling is the one implemented in NOCMODL. It should prevent the residual from growing linearly with `dt`. * Reduce MAX_ITER to 50 and adjust tolerance. * Newton error propagation based convergence check. * Add tests. Includes a test that solves the pump equation ~ X + Y <-> Z with extreme compartment sizes and steady state calculation. * Fix codegen issues specific to NVHPC. --- src/codegen/codegen_neuron_cpp_visitor.cpp | 14 +- src/solver/newton/newton.hpp | 39 +++- src/visitors/sympy_solver_visitor.cpp | 10 +- test/unit/visitor/sympy_solver.cpp | 230 ++++++++++---------- test/usecases/CMakeLists.txt | 1 + test/usecases/steady_state/minipump.mod | 43 ++++ test/usecases/steady_state/test_minipump.py | 86 ++++++++ 7 files changed, 293 insertions(+), 130 deletions(-) create mode 100644 test/usecases/steady_state/minipump.mod create mode 100644 test/usecases/steady_state/test_minipump.py diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index e4759b446..770bc534d 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -1526,9 +1526,21 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) { print_rename_state_vars(); + if (!info.changed_dt.empty()) { + printer->fmt_line("double _save_prev_dt = {};", + get_variable_name(naming::NTHREAD_DT_VARIABLE)); + printer->fmt_line("{} = {};", + get_variable_name(naming::NTHREAD_DT_VARIABLE), + info.changed_dt); + } + print_initial_block(info.initial_node); - printer->pop_block(); + if (!info.changed_dt.empty()) { + printer->fmt_line("{} = _save_prev_dt;", get_variable_name(naming::NTHREAD_DT_VARIABLE)); + } + + printer->pop_block(); printer->pop_block(); } diff --git a/src/solver/newton/newton.hpp b/src/solver/newton/newton.hpp index 77d6470e1..bd627d0db 100644 --- a/src/solver/newton/newton.hpp +++ b/src/solver/newton/newton.hpp @@ -34,8 +34,33 @@ namespace newton { * @{ */ -static constexpr int MAX_ITER = 1e3; -static constexpr double EPS = 1e-12; +static constexpr int MAX_ITER = 50; +static constexpr double EPS = 1e-13; + +template +EIGEN_DEVICE_FUNC bool is_converged(const Eigen::Matrix& X, + const Eigen::Matrix& J, + const Eigen::Matrix& F, + double eps) { + bool converged = true; + double square_eps = eps * eps; + for (Eigen::Index i = 0; i < N; ++i) { + double square_error = 0.0; + for (Eigen::Index j = 0; j < N; ++j) { + double JX = J(i, j) * X(j); + square_error += JX * JX; + } + + if (F(i) * F(i) > square_eps * square_error) { + converged = false; +// The NVHPC is buggy and wont allow us to short-circuit. +#ifndef __NVCOMPILER + return converged; +#endif + } + } + return converged; +} /** * \brief Newton method with user-provided Jacobian @@ -58,17 +83,14 @@ EIGEN_DEVICE_FUNC int newton_solver(Eigen::Matrix& X, int max_iter = MAX_ITER) { // Vector to store result of function F(X): Eigen::Matrix F; - // Matrix to store jacobian of F(X): + // Matrix to store Jacobian of F(X): Eigen::Matrix J; // Solver iteration count: int iter = -1; while (++iter < max_iter) { // calculate F, J from X using user-supplied functor functor(X, F, J); - // get error norm: here we use sqrt(|F|^2) - double error = F.norm(); - if (error < eps) { - // we have converged: return iteration count + if (is_converged(X, J, F, eps)) { return iter; } // In Eigen the default storage order is ColMajor. @@ -109,8 +131,7 @@ EIGEN_DEVICE_FUNC int newton_solver_small_N(Eigen::Matrix& X, int iter = -1; while (++iter < max_iter) { functor(X, F, J); - double error = F.norm(); - if (error < eps) { + if (is_converged(X, J, F, eps)) { return iter; } // The inverse can be called from within OpenACC regions without any issue, as opposed to diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index d944f9446..206145efb 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -543,15 +543,15 @@ void SympySolverVisitor::visit_derivative_block(ast::DerivativeBlock& node) { pre_solve_statements.push_back(std::move(expression)); } // replace ODE with Euler equation - eq = x; + eq = "("; + eq.append(x); eq.append(x_array_index); - eq.append(" = "); + eq.append(" - "); eq.append(old_x); - eq.append(" + "); + eq.append(") / "); eq.append(codegen::naming::NTHREAD_DT_VARIABLE); - eq.append(" * ("); + eq.append(" = "); eq.append(dxdt); - eq.append(")"); logger->debug("SympySolverVisitor :: -> constructed Euler eq: {}", eq); } } diff --git a/test/unit/visitor/sympy_solver.cpp b/test/unit/visitor/sympy_solver.cpp index 5a0e4c42b..486a24311 100644 --- a/test/unit/visitor/sympy_solver.cpp +++ b/test/unit/visitor/sympy_solver.cpp @@ -647,8 +647,8 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = m }{ - nmodl_eigen_f[0] = (-nmodl_eigen_x[0]*dt+dt*mInf+mTau*(-nmodl_eigen_x[0]+old_m))/mTau - nmodl_eigen_j[0] = -(dt+mTau)/mTau + nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+mInf)+mTau*(-nmodl_eigen_x[0]+old_m))/(dt*mTau) + nmodl_eigen_j[0] = (-dt-mTau)/(dt*mTau) }{ m = nmodl_eigen_x[0] }{ @@ -686,11 +686,11 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_y + nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_y)/dt nmodl_eigen_j[0] = 0 - nmodl_eigen_j[2] = -1.0 - nmodl_eigen_f[1] = -nmodl_eigen_x[0]+b*dt+old_x - nmodl_eigen_j[1] = -1.0 + nmodl_eigen_j[2] = -1/dt + nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_x)/dt + nmodl_eigen_j[1] = -1/dt nmodl_eigen_j[3] = 0 }{ x = nmodl_eigen_x[0] @@ -730,11 +730,11 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = M[0] nmodl_eigen_x[1] = M[1] }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_M_1 + nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_M_1)/dt nmodl_eigen_j[0] = 0 - nmodl_eigen_j[2] = -1.0 - nmodl_eigen_f[1] = -nmodl_eigen_x[0]+b*dt+old_M_0 - nmodl_eigen_j[1] = -1.0 + nmodl_eigen_j[2] = -1/dt + nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_M_0)/dt + nmodl_eigen_j[1] = -1/dt nmodl_eigen_j[3] = 0 }{ M[0] = nmodl_eigen_x[0] @@ -775,13 +775,13 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x - nmodl_eigen_j[0] = -1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt + nmodl_eigen_j[0] = -1/dt nmodl_eigen_j[2] = 0 b = b+1 - nmodl_eigen_f[1] = -nmodl_eigen_x[1]+b*dt+old_y + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt nmodl_eigen_j[1] = 0 - nmodl_eigen_j[3] = -1.0 + nmodl_eigen_j[3] = -1/dt }{ x = nmodl_eigen_x[0] y = nmodl_eigen_x[1] @@ -853,12 +853,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x - nmodl_eigen_j[0] = -1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt + nmodl_eigen_j[0] = -1/dt nmodl_eigen_j[2] = 0 - nmodl_eigen_f[1] = -nmodl_eigen_x[1]+b*dt+old_y + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt nmodl_eigen_j[1] = 0 - nmodl_eigen_j[3] = -1.0 + nmodl_eigen_j[3] = -1/dt }{ x = nmodl_eigen_x[0] y = nmodl_eigen_x[1] @@ -903,15 +903,15 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x - nmodl_eigen_j[0] = -1.0 - nmodl_eigen_j[2] = a*dt + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt + nmodl_eigen_j[0] = -1/dt + nmodl_eigen_j[2] = a IF (b == 1) { a = a+1 } - nmodl_eigen_f[1] = nmodl_eigen_x[0]*dt+nmodl_eigen_x[1]*a*dt-nmodl_eigen_x[1]+old_y - nmodl_eigen_j[1] = dt - nmodl_eigen_j[3] = a*dt-1.0 + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt + nmodl_eigen_j[1] = 1.0 + nmodl_eigen_j[3] = a-1/dt }{ x = nmodl_eigen_x[0] y = nmodl_eigen_x[1] @@ -929,15 +929,15 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x - nmodl_eigen_j[0] = -1.0 - nmodl_eigen_j[2] = a*dt + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt + nmodl_eigen_j[0] = -1/dt + nmodl_eigen_j[2] = a IF (b == 1) { a = a+1 } - nmodl_eigen_f[1] = nmodl_eigen_x[0]*dt+nmodl_eigen_x[1]*a*dt-nmodl_eigen_x[1]+old_y - nmodl_eigen_j[1] = dt - nmodl_eigen_j[3] = a*dt-1.0 + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt + nmodl_eigen_j[1] = 1.0 + nmodl_eigen_j[3] = a-1/dt }{ x = nmodl_eigen_x[0] y = nmodl_eigen_x[1] @@ -984,18 +984,18 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[1] = y nmodl_eigen_x[2] = z }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x - nmodl_eigen_j[0] = -1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt + nmodl_eigen_j[0] = -1/dt nmodl_eigen_j[3] = 0 - nmodl_eigen_j[6] = a*dt - nmodl_eigen_f[1] = 2.0*nmodl_eigen_x[0]*dt-nmodl_eigen_x[1]+c*dt+old_y - nmodl_eigen_j[1] = 2.0*dt - nmodl_eigen_j[4] = -1.0 + nmodl_eigen_j[6] = a + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt + nmodl_eigen_j[1] = 2.0 + nmodl_eigen_j[4] = -1/dt nmodl_eigen_j[7] = 0 - nmodl_eigen_f[2] = -nmodl_eigen_x[1]*dt+nmodl_eigen_x[2]*d*dt-nmodl_eigen_x[2]+old_z + nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt nmodl_eigen_j[2] = 0 - nmodl_eigen_j[5] = -dt - nmodl_eigen_j[8] = d*dt-1.0 + nmodl_eigen_j[5] = -1.0 + nmodl_eigen_j[8] = d-1/dt }{ x = nmodl_eigen_x[0] y = nmodl_eigen_x[1] @@ -1016,18 +1016,18 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[1] = y nmodl_eigen_x[2] = z }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x - nmodl_eigen_j[0] = -1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt + nmodl_eigen_j[0] = -1/dt nmodl_eigen_j[3] = 0 - nmodl_eigen_j[6] = a*dt - nmodl_eigen_f[1] = 2.0*nmodl_eigen_x[0]*dt-nmodl_eigen_x[1]+c*dt+old_y - nmodl_eigen_j[1] = 2.0*dt - nmodl_eigen_j[4] = -1.0 + nmodl_eigen_j[6] = a + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt + nmodl_eigen_j[1] = 2.0 + nmodl_eigen_j[4] = -1/dt nmodl_eigen_j[7] = 0 - nmodl_eigen_f[2] = -nmodl_eigen_x[1]*dt+nmodl_eigen_x[2]*d*dt-nmodl_eigen_x[2]+old_z + nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt nmodl_eigen_j[2] = 0 - nmodl_eigen_j[5] = -dt - nmodl_eigen_j[8] = d*dt-1.0 + nmodl_eigen_j[5] = -1.0 + nmodl_eigen_j[8] = d-1/dt }{ x = nmodl_eigen_x[0] y = nmodl_eigen_x[1] @@ -1070,12 +1070,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc - nmodl_eigen_j[0] = -a*dt-1.0 - nmodl_eigen_j[2] = b*dt - nmodl_eigen_f[1] = nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[1]*b*dt-nmodl_eigen_x[1]+old_m - nmodl_eigen_j[1] = a*dt - nmodl_eigen_j[3] = -b*dt-1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt + nmodl_eigen_j[0] = -a-1/dt + nmodl_eigen_j[2] = b + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt + nmodl_eigen_j[1] = a + nmodl_eigen_j[3] = -b-1/dt }{ mc = nmodl_eigen_x[0] m = nmodl_eigen_x[1] @@ -1113,9 +1113,9 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc - nmodl_eigen_j[0] = -a*dt-1.0 - nmodl_eigen_j[2] = b*dt + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt + nmodl_eigen_j[0] = -a-1/dt + nmodl_eigen_j[2] = b nmodl_eigen_f[1] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]+1.0 nmodl_eigen_j[1] = -1.0 nmodl_eigen_j[3] = -1.0 @@ -1159,12 +1159,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc - nmodl_eigen_j[0] = -a*dt-1.0 - nmodl_eigen_j[2] = b*dt - nmodl_eigen_f[1] = nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[1]*b*dt-nmodl_eigen_x[1]+old_m - nmodl_eigen_j[1] = a*dt - nmodl_eigen_j[3] = -b*dt-1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt + nmodl_eigen_j[0] = -a-1/dt + nmodl_eigen_j[2] = b + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt + nmodl_eigen_j[1] = a + nmodl_eigen_j[3] = -b-1/dt }{ mc = nmodl_eigen_x[0] m = nmodl_eigen_x[1] @@ -1213,16 +1213,16 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[3] = p0 nmodl_eigen_x[4] = p1 }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*alpha*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*beta*dt+old_c1 - nmodl_eigen_j[0] = -alpha*dt-1.0 - nmodl_eigen_j[5] = beta*dt + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*alpha+nmodl_eigen_x[1]*beta)+old_c1)/dt + nmodl_eigen_j[0] = -alpha-1/dt + nmodl_eigen_j[5] = beta nmodl_eigen_j[10] = 0 nmodl_eigen_j[15] = 0 nmodl_eigen_j[20] = 0 - nmodl_eigen_f[1] = nmodl_eigen_x[0]*alpha*dt-nmodl_eigen_x[1]*beta*dt-nmodl_eigen_x[1]*dt*k3p-nmodl_eigen_x[1]+nmodl_eigen_x[2]*dt*k4+old_o1 - nmodl_eigen_j[1] = alpha*dt - nmodl_eigen_j[6] = -beta*dt-dt*k3p-1.0 - nmodl_eigen_j[11] = dt*k4 + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*alpha-nmodl_eigen_x[1]*beta-nmodl_eigen_x[1]*k3p+nmodl_eigen_x[2]*k4)+old_o1)/dt + nmodl_eigen_j[1] = alpha + nmodl_eigen_j[6] = -beta-k3p-1/dt + nmodl_eigen_j[11] = k4 nmodl_eigen_j[16] = 0 nmodl_eigen_j[21] = 0 nmodl_eigen_f[2] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]+1.0 @@ -1231,12 +1231,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_j[12] = -1.0 nmodl_eigen_j[17] = 0 nmodl_eigen_j[22] = 0 - nmodl_eigen_f[3] = -nmodl_eigen_x[3]*dt*k1ca-nmodl_eigen_x[3]+nmodl_eigen_x[4]*dt*k2+old_p0 + nmodl_eigen_f[3] = (-nmodl_eigen_x[3]+dt*(-nmodl_eigen_x[3]*k1ca+nmodl_eigen_x[4]*k2)+old_p0)/dt nmodl_eigen_j[3] = 0 nmodl_eigen_j[8] = 0 nmodl_eigen_j[13] = 0 - nmodl_eigen_j[18] = -dt*k1ca-1.0 - nmodl_eigen_j[23] = dt*k2 + nmodl_eigen_j[18] = -k1ca-1/dt + nmodl_eigen_j[23] = k2 nmodl_eigen_f[4] = -nmodl_eigen_x[3]-nmodl_eigen_x[4]+1.0 nmodl_eigen_j[4] = 0 nmodl_eigen_j[9] = 0 @@ -1286,8 +1286,8 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = W[0] }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0 - nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt + nmodl_eigen_j[0] = -A[0]+B[0]-1/dt }{ W[0] = nmodl_eigen_x[0] }{ @@ -1328,12 +1328,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = M[0] nmodl_eigen_x[1] = M[1] }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]-nmodl_eigen_x[0]+nmodl_eigen_x[1]*dt*B[0]+old_M_0 - nmodl_eigen_j[0] = -dt*A[0]-1.0 - nmodl_eigen_j[2] = dt*B[0] - nmodl_eigen_f[1] = nmodl_eigen_x[0]*dt*A[1]-nmodl_eigen_x[1]*dt*B[1]-nmodl_eigen_x[1]+old_M_1 - nmodl_eigen_j[1] = dt*A[1] - nmodl_eigen_j[3] = -dt*B[1]-1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[1]*B[0])+old_M_0)/dt + nmodl_eigen_j[0] = -A[0]-1/dt + nmodl_eigen_j[2] = B[0] + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*A[1]-nmodl_eigen_x[1]*B[1])+old_M_1)/dt + nmodl_eigen_j[1] = A[1] + nmodl_eigen_j[3] = -B[1]-1/dt }{ M[0] = nmodl_eigen_x[0] M[1] = nmodl_eigen_x[1] @@ -1372,8 +1372,8 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = W[0] }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0 - nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt + nmodl_eigen_j[0] = -A[0]+B[0]-1/dt }{ W[0] = nmodl_eigen_x[0] }{ @@ -1416,18 +1416,18 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[1] = h nmodl_eigen_x[2] = n }{ - nmodl_eigen_f[0] = (-nmodl_eigen_x[0]*dt+dt*minf+mtau*(-nmodl_eigen_x[0]-3.0*nmodl_eigen_x[1]*dt+old_m))/mtau - nmodl_eigen_f[1] = (-nmodl_eigen_x[1]*dt+dt*hinf+htau*(pow(nmodl_eigen_x[0], 2)*dt-nmodl_eigen_x[1]+old_h))/htau - nmodl_eigen_f[2] = (-nmodl_eigen_x[2]*dt+dt*ninf+ntau*(-nmodl_eigen_x[2]+old_n))/ntau - nmodl_eigen_j[0] = -(dt+mtau)/mtau - nmodl_eigen_j[3] = -3.0*dt + nmodl_eigen_f[0] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt-3.0*nmodl_eigen_x[1]+minf/mtau+old_m/dt + nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau) + nmodl_eigen_j[3] = -3.0 nmodl_eigen_j[6] = 0 - nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]*dt - nmodl_eigen_j[4] = -(dt+htau)/htau + nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt + nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0] + nmodl_eigen_j[4] = (-dt-htau)/(dt*htau) nmodl_eigen_j[7] = 0 + nmodl_eigen_f[2] = (dt*(-nmodl_eigen_x[2]+ninf)+ntau*(-nmodl_eigen_x[2]+old_n))/(dt*ntau) nmodl_eigen_j[2] = 0 nmodl_eigen_j[5] = 0 - nmodl_eigen_j[8] = -(dt+ntau)/ntau + nmodl_eigen_j[8] = (-dt-ntau)/(dt*ntau) }{ m = nmodl_eigen_x[0] h = nmodl_eigen_x[1] @@ -1474,12 +1474,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = m nmodl_eigen_x[1] = h }{ - nmodl_eigen_f[0] = (-nmodl_eigen_x[0]*dt+dt*minf+mtau*(-nmodl_eigen_x[0]+old_m))/mtau - nmodl_eigen_f[1] = (-nmodl_eigen_x[1]*dt+dt*hinf+htau*(pow(nmodl_eigen_x[0], 2)*dt-nmodl_eigen_x[1]+old_h))/htau - nmodl_eigen_j[0] = -(dt+mtau)/mtau + nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+minf)+mtau*(-nmodl_eigen_x[0]+old_m))/(dt*mtau) + nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau) nmodl_eigen_j[2] = 0 - nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]*dt - nmodl_eigen_j[3] = -(dt+htau)/htau + nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau- nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt + nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0] + nmodl_eigen_j[3] = (-dt-htau)/(dt*htau) }{ m = nmodl_eigen_x[0] h = nmodl_eigen_x[1] @@ -1497,12 +1497,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = m nmodl_eigen_x[1] = h }{ - nmodl_eigen_f[0] = (-nmodl_eigen_x[1]*dt+dt*hinf+htau*(pow(nmodl_eigen_x[0], 2)*dt-nmodl_eigen_x[1]+old_h))/htau - nmodl_eigen_f[1] = (-nmodl_eigen_x[0]*dt+dt*minf+mtau*(-nmodl_eigen_x[0]+nmodl_eigen_x[1]*dt+old_m))/mtau - nmodl_eigen_j[0] = 2.0*nmodl_eigen_x[0]*dt - nmodl_eigen_j[2] = -(dt+htau)/htau - nmodl_eigen_j[1] = -(dt+mtau)/mtau - nmodl_eigen_j[3] = dt + nmodl_eigen_f[0] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt + nmodl_eigen_j[0] = 2.0*nmodl_eigen_x[0] + nmodl_eigen_j[2] = (-dt-htau)/(dt*htau) + nmodl_eigen_f[1] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt+nmodl_eigen_x[1]+minf/mtau+old_m/dt + nmodl_eigen_j[1] = (-dt-mtau)/(dt*mtau) + nmodl_eigen_j[3] = 1.0 }{ m = nmodl_eigen_x[0] h = nmodl_eigen_x[1] @@ -1862,12 +1862,12 @@ SCENARIO("Solve KINETIC block using SympySolver Visitor", "[visitor][solver][sym nmodl_eigen_x[0] = C1 nmodl_eigen_x[1] = C2 }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*kf0_-nmodl_eigen_x[0]+nmodl_eigen_x[1]*dt*kb0_+old_C1 - nmodl_eigen_j[0] = -dt*kf0_-1.0 - nmodl_eigen_j[2] = dt*kb0_ - nmodl_eigen_f[1] = nmodl_eigen_x[0]*dt*kf0_-nmodl_eigen_x[1]*dt*kb0_-nmodl_eigen_x[1]+old_C2 - nmodl_eigen_j[1] = dt*kf0_ - nmodl_eigen_j[3] = -dt*kb0_-1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt + nmodl_eigen_j[0] = -kf0_-1/dt + nmodl_eigen_j[2] = kb0_ + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt + nmodl_eigen_j[1] = kf0_ + nmodl_eigen_j[3] = -kb0_-1/dt }{ C1 = nmodl_eigen_x[0] C2 = nmodl_eigen_x[1] @@ -1904,20 +1904,20 @@ SCENARIO("Solve KINETIC block using SympySolver Visitor", "[visitor][solver][sym EIGEN_NEWTON_SOLVE[2]{ LOCAL kf0_, kb0_, old_C1, old_C2 }{ - kb0_ = lowergamma(v) kf0_ = beta(v) + kb0_ = lowergamma(v) old_C1 = C1 old_C2 = C2 }{ nmodl_eigen_x[0] = C1 nmodl_eigen_x[1] = C2 }{ - nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*kf0_-nmodl_eigen_x[0]+nmodl_eigen_x[1]*dt*kb0_+old_C1 - nmodl_eigen_j[0] = -dt*kf0_-1.0 - nmodl_eigen_j[2] = dt*kb0_ - nmodl_eigen_f[1] = nmodl_eigen_x[0]*dt*kf0_-nmodl_eigen_x[1]*dt*kb0_-nmodl_eigen_x[1]+old_C2 - nmodl_eigen_j[1] = dt*kf0_ - nmodl_eigen_j[3] = -dt*kb0_-1.0 + nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt + nmodl_eigen_j[0] = -kf0_-1/dt + nmodl_eigen_j[2] = kb0_ + nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt + nmodl_eigen_j[1] = kf0_ + nmodl_eigen_j[3] = -kb0_-1/dt }{ C1 = nmodl_eigen_x[0] C2 = nmodl_eigen_x[1] diff --git a/test/usecases/CMakeLists.txt b/test/usecases/CMakeLists.txt index cc08856bf..ba5bf0fc6 100644 --- a/test/usecases/CMakeLists.txt +++ b/test/usecases/CMakeLists.txt @@ -21,6 +21,7 @@ set(NMODL_USECASE_DIRS random suffix state + steady_state table useion at_time) diff --git a/test/usecases/steady_state/minipump.mod b/test/usecases/steady_state/minipump.mod new file mode 100644 index 000000000..2027aecc8 --- /dev/null +++ b/test/usecases/steady_state/minipump.mod @@ -0,0 +1,43 @@ +NEURON { + SUFFIX minipump +} + +PARAMETER { + volA = 1e9 + volB = 1e9 + volC = 13.0 + kf = 3.0 + kb = 4.0 + + run_steady_state = 0.0 +} + +STATE { + X + Y + Z +} + +INITIAL { + X = 40.0 + Y = 8.0 + Z = 1.0 + + if(run_steady_state > 0.0) { + SOLVE state STEADYSTATE sparse + } +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT volA {X} + COMPARTMENT volB {Y} + COMPARTMENT volC {Z} + + ~ X + Y <-> Z (kf, kb) + + CONSERVE Y + Z = 8.0*volB + 1.0*volC +} diff --git a/test/usecases/steady_state/test_minipump.py b/test/usecases/steady_state/test_minipump.py new file mode 100644 index 000000000..4521bb92a --- /dev/null +++ b/test/usecases/steady_state/test_minipump.py @@ -0,0 +1,86 @@ +import sys +import pickle + +import numpy as np + +from neuron import h, gui + + +def run(steady_state): + s = h.Section() + + s.insert("minipump") + s.diam = 1.0 + + t_hoc = h.Vector().record(h._ref_t) + X_hoc = h.Vector().record(s(0.5).minipump._ref_X) + Y_hoc = h.Vector().record(s(0.5).minipump._ref_Y) + Z_hoc = h.Vector().record(s(0.5).minipump._ref_Z) + + h.run_steady_state_minipump = 1.0 if steady_state else 0.0 + + h.stdinit() + h.continuerun(1.0) + + t = np.array(t_hoc.as_numpy()) + X = np.array(X_hoc.as_numpy()) + Y = np.array(Y_hoc.as_numpy()) + Z = np.array(Z_hoc.as_numpy()) + + return t, X, Y, Z + + +def traces_filename(steady_state): + return "test_minipump{}.pkl".format("-steady_state" if steady_state else "") + + +def save_traces(t, X, Y, Z, steady_state): + with open(traces_filename(steady_state), "bw") as f: + pickle.dump({"t": t, "X": X, "Y": Y, "Z": Z}, f) + + +def load_traces(steady_state): + with open(traces_filename(steady_state), "br") as f: + d = pickle.load(f) + + return d["t"], d["X"], d["Y"], d["Z"] + + +def assert_almost_equal(actual, expected, rtol): + decimal = np.ceil(-np.log10(rtol * np.max(expected))) + np.testing.assert_almost_equal(actual, expected, decimal=decimal) + + +def check_traces(t, X, Y, Z, steady_state): + if len(sys.argv) < 2: + return + + codegen = sys.argv[1] + if codegen == "nocmodl": + save_traces(t, X, Y, Z, steady_state) + + else: + t_ref, X_ref, Y_ref, Z_ref = load_traces(steady_state) + + assert_almost_equal(t, t_ref, rtol=1e-8) + assert_almost_equal(X, X_ref, rtol=1e-8) + assert_almost_equal(Y, Y_ref, rtol=1e-8) + assert_almost_equal(Z, Z_ref, rtol=1e-8) + + +def check_solution(steady_state): + t, X, Y, Z = run(steady_state) + check_traces(t, X, Y, Z, steady_state=steady_state) + + +def test_steady_state(): + check_solution(steady_state=True) + + +def test_no_steady_state(): + check_solution(steady_state=False) + + +if __name__ == "__main__": + test_steady_state() + test_no_steady_state()