add example and update faq to include mixed precision
g-w1 committed Feb 11, 2024
1 parent 089ede5 commit 1352213
Showing 2 changed files with 57 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/user/faq.rst
@@ -10,6 +10,8 @@ General usage
| **A**: `#5`_
- | **Q**: By default, DeepXDE uses ``float32``. How can I use ``float64``?
| **A**: `#28`_
- | **Q**: How can I use mixed precision training?
| **A**: Use ``dde.config.set_default_float("mixed")`` with the ``tensorflow`` or ``pytorch`` backends. See `this paper <https://arxiv.org/abs/2401.16645>`_ for more information.
- | **Q**: I want to set the global random seeds.
| **A**: `#353`_
- | **Q**: GPU.
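The FAQ entry above adds mixed precision alongside the existing float64 option. For context, enabling it amounts to choosing a supported backend and setting the default float before any data, network, or model objects are built (a minimal sketch; the full example added by this commit follows below):

    # Run the script with a backend that supports mixed precision, e.g.
    #   DDE_BACKEND=pytorch python your_script.py   (or DDE_BACKEND=tensorflow)
    import deepxde as dde

    # Set the default float before constructing data, networks, or the Model,
    # as in the example below.
    dde.config.set_default_float("mixed")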
55 changes: 55 additions & 0 deletions examples/pinn_forward/Burgers_mixed.py
@@ -0,0 +1,55 @@
"""Backend supported: tensorflow, pytorch
The exact same as Burgers.py, but using mixed precision instead of float32.
This preserves accuracy while speeding up training (especially with larger training runs).
"""

import deepxde as dde
import numpy as np

dde.config.set_default_float("mixed")


def gen_testdata():
    """Load the reference solution of the Burgers equation and flatten it into (x, t) -> u pairs."""
    data = np.load("../dataset/Burgers.npz")
    t, x, exact = data["t"], data["x"], data["usol"].T
    xx, tt = np.meshgrid(x, t)
    X = np.vstack((np.ravel(xx), np.ravel(tt))).T
    y = exact.flatten()[:, None]
    return X, y


def pde(x, y):
    # Burgers equation residual: u_t + u * u_x - (0.01 / pi) * u_xx = 0
    dy_x = dde.grad.jacobian(y, x, i=0, j=0)
    dy_t = dde.grad.jacobian(y, x, i=0, j=1)
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)
    return dy_t + y * dy_x - 0.01 / np.pi * dy_xx


geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 0.99)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary)
ic = dde.icbc.IC(
geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial
)

data = dde.data.TimePDE(
geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160
)
net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
model = dde.Model(data, net)

model.compile("adam", lr=1e-3)
losshistory, train_state = model.train(iterations=15000)
# We have to disable L-BFGS since it does not support mixed precision
# model.compile("L-BFGS")
# losshistory, train_state = model.train()
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

X, y_true = gen_testdata()
y_pred = model.predict(X)
f = model.predict(X, operator=pde)
print("Mean residual:", np.mean(np.absolute(f)))
print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred))
np.savetxt("test.dat", np.hstack((X, y_true, y_pred)))
