Add NNCG to optimizers submodule #1661

Open
wants to merge 11 commits into master
74 changes: 73 additions & 1 deletion deepxde/model.py
@@ -363,11 +363,27 @@ def closure():
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

        def train_step_nncg(inputs, targets, auxiliary_vars):
            def closure():
                return get_loss_grad_nncg(inputs, targets, auxiliary_vars)

            self.opt.step(closure)

        def get_loss_grad_nncg(inputs, targets, auxiliary_vars):
            losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
            total_loss = torch.sum(losses)
            self.opt.zero_grad()
            grad_tuple = torch.autograd.grad(total_loss, trainable_variables,
Owner commented on this line: Is this necessary? In the code of L-BFGS (https://pytorch.org/docs/stable/_modules/torch/optim/lbfgs.html#LBFGS), this seems to be computed in step() by self._gather_flat_grad().

                                             create_graph=True)
            return total_loss, grad_tuple

        # Callables
        self.outputs = outputs
        self.outputs_losses_train = outputs_losses_train
        self.outputs_losses_test = outputs_losses_test
        self.train_step = train_step
        self.train_step_nncg = train_step_nncg
        self.get_loss_grad_nncg = get_loss_grad_nncg

    def _compile_jax(self, lr, loss_fn, decay):
        """jax"""
@@ -636,12 +652,22 @@ def train(
        self._test()
        self.callbacks.on_train_begin()
        if optimizers.is_external_optimizer(self.opt_name):
            if self.opt_name == "NNCG" and backend_name != "pytorch":
                raise ValueError(
                    "The optimizer 'NNCG' is only supported for the backend PyTorch."
                )
            if backend_name == "tensorflow.compat.v1":
                self._train_tensorflow_compat_v1_scipy(display_every)
            elif backend_name == "tensorflow":
                self._train_tensorflow_tfp()
            elif backend_name == "pytorch":
                if self.opt_name == "L-BFGS":
                    self._train_pytorch_lbfgs()
                elif self.opt_name == "NNCG":
                    self._train_pytorch_nncg(iterations, display_every)
                else:
                    raise ValueError(
                        "Only 'L-BFGS' and 'NNCG' are supported as external "
                        "optimizers for PyTorch."
                    )
            elif backend_name == "paddle":
                self._train_paddle_lbfgs()
            else:
@@ -785,6 +811,52 @@ def _train_pytorch_lbfgs(self):
            if self.stop_training:
                break

    def _train_pytorch_nncg(self, iterations, display_every):
        # Loop over the iterations -- takes inspiration from
        # _train_pytorch_lbfgs and _train_sgd.
        for i in range(iterations):
            # 1. Perform the appropriate begin callbacks.
            self.callbacks.on_epoch_begin()
            self.callbacks.on_batch_begin()

            # 2. Update the preconditioner (if applicable).
            # 2.1. Whether the preconditioner is updated is controlled by the
            # NNCG_options entry "updatefreq"; do the usual modular arithmetic
            # from there.
            if i % optimizers.NNCG_options["updatefreq"] == 0:
                self.opt.zero_grad()
                # 2.2. Get the sum of the losses as in train_step(), and use
                # torch.autograd.grad to get the gradient.
                _, grad_tuple = self.get_loss_grad_nncg(
                    self.train_state.X_train,
                    self.train_state.y_train,
                    self.train_state.train_aux_vars,
                )
                # 2.3. Plug the gradient into the NNCG update_preconditioner
                # function to perform the update.
                self.opt.update_preconditioner(grad_tuple)

            # 3. Call the train step.
            self.train_step_nncg(
                self.train_state.X_train,
                self.train_state.y_train,
                self.train_state.train_aux_vars,
            )

            # 4. Use self._test() if needed.
            self.train_state.epoch += 1
            self.train_state.step += 1
            if self.train_state.step % display_every == 0 or i + 1 == iterations:
                self._test()

            # 5. Perform the appropriate end callbacks.
            self.callbacks.on_batch_end()
            self.callbacks.on_epoch_end()

            # 6. Allow training to stop (if self.stop_training).
            if self.stop_training:
                break

    def _train_paddle_lbfgs(self):
        prev_n_iter = 0

2 changes: 1 addition & 1 deletion deepxde/optimizers/__init__.py
@@ -1,7 +1,7 @@
import importlib
import sys

from .config import LBFGS_options, set_LBFGS_options, NNCG_options, set_NNCG_options
from ..backend import backend_name


53 changes: 52 additions & 1 deletion deepxde/optimizers/config.py
@@ -1,9 +1,10 @@
__all__ = ["set_LBFGS_options", "set_hvd_opt_options"]
__all__ = ["set_LBFGS_options", "set_NNCG_options", "set_hvd_opt_options"]

from ..backend import backend_name
from ..config import hvd

LBFGS_options = {}
NNCG_options = {}
if hvd is not None:
    hvd_opt_options = {}

@@ -59,6 +60,55 @@ def set_LBFGS_options(
LBFGS_options["maxfun"] = maxfun if maxfun is not None else int(maxiter * 1.25)
LBFGS_options["maxls"] = maxls

def set_NNCG_options(
    lr=1,
    rank=10,
    mu=1e-4,
    updatefreq=20,
    chunksz=1,
    cgtol=1e-16,
    cgmaxiter=1000,
    lsfun="armijo",
    verbose=False,
):
    """Sets the hyperparameters of NysNewtonCG (NNCG).

    Args:
        lr (float): `lr` (torch).
            Learning rate (before line search).
        rank (int): `rank` (torch).
            Rank of the preconditioner matrix used in preconditioned conjugate
            gradient.
        mu (float): `mu` (torch).
            Hessian damping parameter.
        updatefreq (int): How often the preconditioner matrix in preconditioned
            conjugate gradient is updated. This parameter is not used inside NNCG
            itself; it is used by _train_pytorch_nncg in deepxde/model.py.
        chunksz (int): `chunk_size` (torch).
            Number of Hessian-vector products to compute in parallel when
            constructing the preconditioner. If `chunk_size` is 1, the
            Hessian-vector products are computed serially.
        cgtol (float): `cg_tol` (torch).
            Convergence tolerance for the conjugate gradient method. The iteration
            stops when `||r||_2 <= cgtol`, where `r` is the residual. Note that this
            condition is based on the absolute tolerance, not the relative tolerance.
        cgmaxiter (int): `cg_max_iters` (torch).
            Maximum number of iterations for the conjugate gradient method.
        lsfun (str): `line_search_fn` (torch).
            The line search function used to find the step size. The default value
            is "armijo". The other option is None.
        verbose (bool): `verbose` (torch).
            If `True`, print the eigenvalues of the Nyström approximation of the
            Hessian.
    """
    NNCG_options["lr"] = lr
    NNCG_options["rank"] = rank
    NNCG_options["mu"] = mu
    NNCG_options["updatefreq"] = updatefreq
    NNCG_options["chunksz"] = chunksz
    NNCG_options["cgtol"] = cgtol
    NNCG_options["cgmaxiter"] = cgmaxiter
    NNCG_options["lsfun"] = lsfun
    NNCG_options["verbose"] = verbose

def set_hvd_opt_options(
    compression=None,
@@ -91,6 +141,7 @@ def set_hvd_opt_options(


set_LBFGS_options()
set_NNCG_options()
if hvd is not None:
    set_hvd_opt_options()
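
For reference, a minimal usage sketch of how these options would be consumed end to end. The set_NNCG_options call and the "NNCG" optimizer name are taken from this diff; the surrounding workflow (data, net, Model) is the usual DeepXDE pattern and is only assumed here, not part of this file.

import deepxde as dde

# Tune NNCG hyperparameters before compiling; unset options keep the defaults above.
dde.optimizers.set_NNCG_options(rank=20, updatefreq=50)

# Usual DeepXDE workflow (sketch):
# model = dde.Model(data, net)
# model.compile("NNCG")
# losshistory, train_state = model.train(iterations=1000)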
