Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reapply "Newton error propagation based convergence check." #1419

Merged
merged 5 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1526,9 +1526,21 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {

print_rename_state_vars();

if (!info.changed_dt.empty()) {
printer->fmt_line("double _save_prev_dt = {};",
get_variable_name(naming::NTHREAD_DT_VARIABLE));
printer->fmt_line("{} = {};",
get_variable_name(naming::NTHREAD_DT_VARIABLE),
info.changed_dt);
}

print_initial_block(info.initial_node);
printer->pop_block();

if (!info.changed_dt.empty()) {
printer->fmt_line("{} = _save_prev_dt;", get_variable_name(naming::NTHREAD_DT_VARIABLE));
}

printer->pop_block();
printer->pop_block();
}

Expand Down
39 changes: 30 additions & 9 deletions src/solver/newton/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,33 @@ namespace newton {
* @{
*/

static constexpr int MAX_ITER = 1e3;
static constexpr double EPS = 1e-12;
static constexpr int MAX_ITER = 50;
1uc marked this conversation as resolved.
Show resolved Hide resolved
static constexpr double EPS = 1e-13;

template <int N>
EIGEN_DEVICE_FUNC bool is_converged(const Eigen::Matrix<double, N, 1>& X,
const Eigen::Matrix<double, N, N>& J,
const Eigen::Matrix<double, N, 1>& F,
double eps) {
bool converged = true;
double square_eps = eps * eps;
for (Eigen::Index i = 0; i < N; ++i) {
double square_error = 0.0;
for (Eigen::Index j = 0; j < N; ++j) {
double JX = J(i, j) * X(j);
square_error += JX * JX;
}

if (F(i) * F(i) > square_eps * square_error) {
converged = false;
// The NVHPC is buggy and wont allow us to short-circuit.
#ifndef __NVCOMPILER
return converged;
#endif
}
}
return converged;
}

/**
* \brief Newton method with user-provided Jacobian
Expand All @@ -58,17 +83,14 @@ EIGEN_DEVICE_FUNC int newton_solver(Eigen::Matrix<double, N, 1>& X,
int max_iter = MAX_ITER) {
// Vector to store result of function F(X):
Eigen::Matrix<double, N, 1> F;
// Matrix to store jacobian of F(X):
// Matrix to store Jacobian of F(X):
Eigen::Matrix<double, N, N> J;
// Solver iteration count:
int iter = -1;
while (++iter < max_iter) {
// calculate F, J from X using user-supplied functor
functor(X, F, J);
// get error norm: here we use sqrt(|F|^2)
double error = F.norm();
if (error < eps) {
// we have converged: return iteration count
if (is_converged(X, J, F, eps)) {
return iter;
}
// In Eigen the default storage order is ColMajor.
Expand Down Expand Up @@ -109,8 +131,7 @@ EIGEN_DEVICE_FUNC int newton_solver_small_N(Eigen::Matrix<double, N, 1>& X,
int iter = -1;
while (++iter < max_iter) {
functor(X, F, J);
double error = F.norm();
if (error < eps) {
if (is_converged(X, J, F, eps)) {
return iter;
}
// The inverse can be called from within OpenACC regions without any issue, as opposed to
Expand Down
10 changes: 5 additions & 5 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,15 @@ void SympySolverVisitor::visit_derivative_block(ast::DerivativeBlock& node) {
pre_solve_statements.push_back(std::move(expression));
}
// replace ODE with Euler equation
eq = x;
eq = "(";
eq.append(x);
eq.append(x_array_index);
eq.append(" = ");
eq.append(" - ");
eq.append(old_x);
eq.append(" + ");
eq.append(") / ");
eq.append(codegen::naming::NTHREAD_DT_VARIABLE);
eq.append(" * (");
eq.append(" = ");
eq.append(dxdt);
eq.append(")");
logger->debug("SympySolverVisitor :: -> constructed Euler eq: {}", eq);
}
}
Expand Down
Loading
Loading