Skip to content

Commit

Permalink
Add PI-DeepONet example for 1D Poisson equation.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed May 15, 2023
1 parent 1604348 commit 0877e11
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 1 deletion.
8 changes: 7 additions & 1 deletion deepxde/data/function_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from sklearn import gaussian_process as gp

from .. import config
from ..utils import isclose
from ..utils import isclose, return_tensor



class FunctionSpace(abc.ABC):
Expand Down Expand Up @@ -90,6 +91,7 @@ def random(self, size):
def eval_one(self, feature, x):
return np.dot(feature, x ** np.arange(self.N))

@return_tensor
def eval_batch(self, features, xs):
mat = np.ones((self.N, len(xs)))
for i in range(1, self.N):
Expand Down Expand Up @@ -119,6 +121,7 @@ def random(self, size):
def eval_one(self, feature, x):
return np.polynomial.chebyshev.chebval(2 * x - 1, feature)

@return_tensor
def eval_batch(self, features, xs):
return np.polynomial.chebyshev.chebval(2 * np.ravel(xs) - 1, features.T)

Expand Down Expand Up @@ -166,6 +169,7 @@ def eval_one(self, feature, x):
)
return f(x)

@return_tensor
def eval_batch(self, features, xs):
if self.interp == "linear":
return np.vstack([np.interp(xs, np.ravel(self.x), y).T for y in features])
Expand Down Expand Up @@ -224,6 +228,7 @@ def eval_one(self, feature, x):
eigfun = [f(x) for f in self.eigfun]
return np.sum(eigfun * feature)

@return_tensor
def eval_batch(self, features, xs):
eigfun = np.array([np.ravel(f(xs)) for f in self.eigfun])
return np.dot(features, eigfun)
Expand Down Expand Up @@ -283,6 +288,7 @@ def eval_one(self, feature, x):
y = np.reshape(feature, (self.N, self.N))
return interpolate.interpn((self.x, self.y), y, x, method=self.interp)[0]

@return_tensor
def eval_batch(self, features, xs):
points = (self.x, self.y)
ys = np.reshape(features, (-1, self.N, self.N))
Expand Down
Binary file added docs/demos/operator/pideeponet_poisson1d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
124 changes: 124 additions & 0 deletions docs/demos/operator/pideeponet_poisson1d.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
Physics-informed DeepONet for 1D Poisson equation
=================================================

Problem setup
-------------

We will learn the solution operator

.. math:: G: f \mapsto u

for the one-dimensional Poisson problem

.. math:: u''(x) = f(x), \qquad x \in [0, 1],

with zero Dirichlet boundary conditions :math:`u(0) = u(1) = 0`.

The source term :math:`f` is supposed to be an arbitrary continuous function.


Implementation
--------------

The solution operator can be learned by training a physics-informed DeepONet.

First, we define the PDE with boundary conditions and the domain:

.. code-block:: python
def equation(x, y, f):
dy_xx = dde.grad.hessian(y, x)
return -dy_xx - f
geom = dde.geometry.Interval(0, 1)
def u_boundary(_):
return 0
def boundary(_, on_boundary):
return on_boundary
bc = dde.icbc.DirichletBC(geom, u_boundary, boundary)
pde = dde.data.PDE(
geom,
equation,
bc,
num_domain=100,
num_boundary=2
)
Next, we specify the function space for :math:`f` and the corresponding evaluation points.
For this example, we use the ``dde.data.PowerSeries`` to get the function space
of polynomials of degree three.
Together with the PDE, the function space is used to define a
PDEOperator ``dde.data.PDEOperatorCartesianProd`` that incorporates the PDE into
the loss function.

.. code-block:: python
degree = 3
space = dde.data.PowerSeries(N=degree+1)
num_eval_points = 10
evaluation_points = geom.uniform_points(num_eval_points, boundary=True)
pde_op = dde.data.PDEOperatorCartesianProd(
pde,
space,
evaluation_points,
num_function=100,
)
The DeepONet can be defined using ``dde.nn.DeepONetCartesianProd``.
The branch net is chosen as a fully connected neural network of size ``[m, 32, p]`` where ``p=32``
and the trunk net is a fully connected neural network of size ``[dim_x, 32, p]``.

.. code-block:: python
dim_x = 1
p = 32
net = dde.nn.DeepONetCartesianProd(
[num_eval_points, 32, p],
[dim_x, 32, p],
activation="tanh",
kernel_initializer="Glorot normal",
)
We define the ``Model`` and train it with either L-BFGS (pytorch) or Adam:

.. code-block:: python
model = dde.Model(pde_op, net)
if dde.backend.backend_name == "pytorch":
dde.optimizers.set_LBFGS_options(maxiter=1000)
model.compile("L-BFGS", metrics=["l2 relative error"])
model.train()
else:
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
model.train(iterations=10000)
Finally, the trained model can be used to predict the solution of the Poisson
equation. We sample the solution for three random representations of :math:`f`.

.. code-block:: python
n = 3
features = space.random(n)
fx = space.eval_batch(features, evaluation_points)
x = geom.uniform_points(100, boundary=True)
y = model.predict((fx, x))
![](pideeponet_poisson1d.png)


Complete code
-------------

.. literalinclude:: ../../../examples/operator/pideeponet_1d_poisson.py
:language: python
102 changes: 102 additions & 0 deletions examples/operator/pideeponet_poisson1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
import numpy as np
import deepxde as dde
import matplotlib.pyplot as plt


# Poisson equation: -u_xx = f
def equation(x, y, f):
dy_xx = dde.grad.hessian(y, x)
return -dy_xx - f


# Domain is interval [0, 1]
geom = dde.geometry.Interval(0, 1)


# Zero Dirichlet BC
def u_boundary(_):
return 0


def boundary(_, on_boundary):
return on_boundary


bc = dde.icbc.DirichletBC(geom, u_boundary, boundary)

# Define PDE
pde = dde.data.PDE(
geom,
equation,
bc,
num_domain=100,
num_boundary=2
)

# Function space for f(x) are polynomials
degree = 3
space = dde.data.PowerSeries(N=degree+1)

# Choose evaluation points
num_eval_points = 10
evaluation_points = geom.uniform_points(num_eval_points, boundary=True)

# Define PDE operator
pde_op = dde.data.PDEOperatorCartesianProd(
pde,
space,
evaluation_points,
num_function=100,
)

# Setup DeepONet
dim_x = 1
p = 32
net = dde.nn.DeepONetCartesianProd(
[num_eval_points, 32, p],
[dim_x, 32, p],
activation="tanh",
kernel_initializer="Glorot normal",
)

# Define and train model
model = dde.Model(pde_op, net)

if dde.backend.backend_name == "pytorch":
dde.optimizers.set_LBFGS_options(maxiter=1000)
model.compile("L-BFGS")
model.train()
else:
model.compile("adam", lr=0.001)
model.train(iterations=10000)

# Plot realisations of f(x)
n = 3
features = space.random(n)
fx = space.eval_batch(features, evaluation_points)

x = geom.uniform_points(100, boundary=True)
y = model.predict((fx, x))

# Setup figure
fig = plt.figure(figsize=(7, 8))
plt.subplot(2, 1, 1)
plt.title("Poisson equation: Source term f(x) and solution u(x)")
plt.ylabel("f(x)")
z = np.zeros_like(x)
plt.plot(x, z, 'k-', alpha=0.1)

# Plot source term f(x)
for i in range(n):
plt.plot(evaluation_points, fx[i], '--')

# Plot solution u(x)
plt.subplot(2, 1, 2)
plt.ylabel("u(x)")
plt.plot(x, z, 'k-', alpha=0.1)
for i in range(n):
plt.plot(x, y[i], '-')
plt.xlabel("x")

plt.show()

0 comments on commit 0877e11

Please sign in to comment.