diff --git a/deepxde/data/function_spaces.py b/deepxde/data/function_spaces.py
index a86863216..ccb09a021 100644
--- a/deepxde/data/function_spaces.py
+++ b/deepxde/data/function_spaces.py
@@ -15,7 +15,8 @@
 from sklearn import gaussian_process as gp
 
 from .. import config
-from ..utils import isclose
+from ..utils import isclose, return_tensor
+
 
 
 class FunctionSpace(abc.ABC):
@@ -90,6 +91,7 @@ def random(self, size):
     def eval_one(self, feature, x):
         return np.dot(feature, x ** np.arange(self.N))
 
+    @return_tensor
     def eval_batch(self, features, xs):
         mat = np.ones((self.N, len(xs)))
         for i in range(1, self.N):
@@ -119,6 +121,7 @@ def random(self, size):
     def eval_one(self, feature, x):
         return np.polynomial.chebyshev.chebval(2 * x - 1, feature)
 
+    @return_tensor
     def eval_batch(self, features, xs):
         return np.polynomial.chebyshev.chebval(2 * np.ravel(xs) - 1, features.T)
 
@@ -166,6 +169,7 @@ def eval_one(self, feature, x):
         )
         return f(x)
 
+    @return_tensor
     def eval_batch(self, features, xs):
         if self.interp == "linear":
             return np.vstack([np.interp(xs, np.ravel(self.x), y).T for y in features])
@@ -224,6 +228,7 @@ def eval_one(self, feature, x):
         eigfun = [f(x) for f in self.eigfun]
         return np.sum(eigfun * feature)
 
+    @return_tensor
     def eval_batch(self, features, xs):
         eigfun = np.array([np.ravel(f(xs)) for f in self.eigfun])
         return np.dot(features, eigfun)
@@ -283,6 +288,7 @@ def eval_one(self, feature, x):
         y = np.reshape(feature, (self.N, self.N))
         return interpolate.interpn((self.x, self.y), y, x, method=self.interp)[0]
 
+    @return_tensor
     def eval_batch(self, features, xs):
         points = (self.x, self.y)
         ys = np.reshape(features, (-1, self.N, self.N))
diff --git a/docs/demos/operator/pideeponet_poisson1d.png b/docs/demos/operator/pideeponet_poisson1d.png
new file mode 100644
index 000000000..e322a5e12
Binary files /dev/null and b/docs/demos/operator/pideeponet_poisson1d.png differ
diff --git a/docs/demos/operator/pideeponet_poisson1d.rst b/docs/demos/operator/pideeponet_poisson1d.rst
new file mode 100644
index 000000000..60ba2ce80
--- /dev/null
+++ b/docs/demos/operator/pideeponet_poisson1d.rst
@@ -0,0 +1,124 @@
+Physics-informed DeepONet for 1D Poisson equation
+=================================================
+
+Problem setup
+-------------
+
+We will learn the solution operator
+
+.. math:: G: f \mapsto u
+
+for the one-dimensional Poisson problem
+
+.. math:: -u''(x) = f(x), \qquad x \in [0, 1],
+
+with zero Dirichlet boundary conditions :math:`u(0) = u(1) = 0`.
+
+The source term :math:`f` is assumed to be an arbitrary continuous function.
+
+
+Implementation
+--------------
+
+The solution operator can be learned by training a physics-informed DeepONet.
+
+First, we define the PDE with its boundary conditions and the domain:
+
+.. code-block:: python
+
+    def equation(x, y, f):
+        dy_xx = dde.grad.hessian(y, x)
+        return -dy_xx - f
+
+    geom = dde.geometry.Interval(0, 1)
+
+    def u_boundary(_):
+        return 0
+
+    def boundary(_, on_boundary):
+        return on_boundary
+
+    bc = dde.icbc.DirichletBC(geom, u_boundary, boundary)
+
+    pde = dde.data.PDE(
+        geom,
+        equation,
+        bc,
+        num_domain=100,
+        num_boundary=2
+    )
+
+
+Next, we specify the function space for :math:`f` and the corresponding evaluation points.
+For this example, we use ``dde.data.PowerSeries`` to get the function space
+of polynomials of degree three.
+Together with the PDE, the function space is used to define the PDE operator
+``dde.data.PDEOperatorCartesianProd``, which incorporates the PDE into
+the loss function.
+
+.. code-block:: python
+
+    degree = 3
+    space = dde.data.PowerSeries(N=degree + 1)
+
+    num_eval_points = 10
+    evaluation_points = geom.uniform_points(num_eval_points, boundary=True)
+
+    pde_op = dde.data.PDEOperatorCartesianProd(
+        pde,
+        space,
+        evaluation_points,
+        num_function=100,
+    )
+
+
+The DeepONet can be defined using ``dde.nn.DeepONetCartesianProd``.
+The branch net is chosen as a fully connected neural network of size
+``[num_eval_points, 32, p]`` with ``p = 32``, and the trunk net is a fully
+connected neural network of size ``[dim_x, 32, p]``.
+
+.. code-block:: python
+
+    dim_x = 1
+    p = 32
+    net = dde.nn.DeepONetCartesianProd(
+        [num_eval_points, 32, p],
+        [dim_x, 32, p],
+        activation="tanh",
+        kernel_initializer="Glorot normal",
+    )
+
+
+We define the ``Model`` and train it, using L-BFGS for the PyTorch backend and
+Adam otherwise:
+
+.. code-block:: python
+
+    model = dde.Model(pde_op, net)
+    if dde.backend.backend_name == "pytorch":
+        dde.optimizers.set_LBFGS_options(maxiter=1000)
+        model.compile("L-BFGS")
+        model.train()
+    else:
+        model.compile("adam", lr=0.001)
+        model.train(iterations=10000)
+
+Finally, the trained model can be used to predict the solution of the Poisson
+equation. We sample the solution for three random realizations of :math:`f`.
+
+.. code-block:: python
+
+    n = 3
+    features = space.random(n)
+    fx = space.eval_batch(features, evaluation_points)
+
+    x = geom.uniform_points(100, boundary=True)
+    y = model.predict((fx, x))
+
+.. image:: pideeponet_poisson1d.png
+
+
+Complete code
+-------------
+
+.. literalinclude:: ../../../examples/operator/pideeponet_poisson1d.py
+  :language: python
diff --git a/examples/operator/pideeponet_poisson1d.py b/examples/operator/pideeponet_poisson1d.py
new file mode 100755
index 000000000..0b296dc31
--- /dev/null
+++ b/examples/operator/pideeponet_poisson1d.py
@@ -0,0 +1,102 @@
+"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
+import numpy as np
+import deepxde as dde
+import matplotlib.pyplot as plt
+
+
+# Poisson equation: -u_xx = f
+def equation(x, y, f):
+    dy_xx = dde.grad.hessian(y, x)
+    return -dy_xx - f
+
+
+# Domain is the interval [0, 1]
+geom = dde.geometry.Interval(0, 1)
+
+
+# Zero Dirichlet BC
+def u_boundary(_):
+    return 0
+
+
+def boundary(_, on_boundary):
+    return on_boundary
+
+
+bc = dde.icbc.DirichletBC(geom, u_boundary, boundary)
+
+# Define the PDE problem
+pde = dde.data.PDE(
+    geom,
+    equation,
+    bc,
+    num_domain=100,
+    num_boundary=2
+)
+
+# Function space for f(x): polynomials of degree <= 3
+degree = 3
+space = dde.data.PowerSeries(N=degree + 1)
+
+# Choose evaluation points for the branch net input
+num_eval_points = 10
+evaluation_points = geom.uniform_points(num_eval_points, boundary=True)
+
+# Define the PDE operator, which builds the PDE residual into the loss
+pde_op = dde.data.PDEOperatorCartesianProd(
+    pde,
+    space,
+    evaluation_points,
+    num_function=100,
+)
+
+# Set up the DeepONet
+dim_x = 1
+p = 32
+net = dde.nn.DeepONetCartesianProd(
+    [num_eval_points, 32, p],
+    [dim_x, 32, p],
+    activation="tanh",
+    kernel_initializer="Glorot normal",
+)
+
+# Define and train the model (L-BFGS for the PyTorch backend, Adam otherwise)
+model = dde.Model(pde_op, net)
+
+if dde.backend.backend_name == "pytorch":
+    dde.optimizers.set_LBFGS_options(maxiter=1000)
+    model.compile("L-BFGS")
+    model.train()
+else:
+    model.compile("adam", lr=0.001)
+    model.train(iterations=10000)
+
+# Sample realisations of f(x) and predict the corresponding solutions u(x)
+n = 3
+features = space.random(n)
+fx = space.eval_batch(features, evaluation_points)
+
+x = geom.uniform_points(100, boundary=True)
+y = model.predict((fx, x))
+
+# Set up the figure
+fig = plt.figure(figsize=(7, 8))
+plt.subplot(2, 1, 1)
+plt.title("Poisson equation: Source term f(x) and solution u(x)")
+plt.ylabel("f(x)")
+z = np.zeros_like(x)
+plt.plot(x, z, 'k-', alpha=0.1)
+
+# Plot source term f(x)
+for i in range(n):
+    plt.plot(evaluation_points, fx[i], '--')
+
+# Plot solution u(x)
+plt.subplot(2, 1, 2)
+plt.ylabel("u(x)")
+plt.plot(x, z, 'k-', alpha=0.1)
+for i in range(n):
+    plt.plot(x, y[i], '-')
+plt.xlabel("x")
+
+plt.show()
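
Since each sampled source term is a polynomial, the predictions made at the end of the new example can be checked against the closed-form solution of -u'' = f with u(0) = u(1) = 0. The sketch below is not part of the patch; it assumes the variables `features`, `x`, `y`, and `n` produced by `pideeponet_poisson1d.py`, and that `PowerSeries.random` returns coefficients in increasing-degree order, as `eval_one` suggests.

    # Hypothetical sanity check, not part of the patch; run after the example script.
    # For a polynomial f, the problem -u'' = f, u(0) = u(1) = 0 has the closed form
    # u(x) = -F(x) + (F(1) - F(0)) * x + F(0), where F is a second antiderivative of f.
    import numpy as np
    from numpy.polynomial import polynomial as P

    def exact_solution(coeffs, x):
        F = P.polyint(coeffs, m=2)  # second antiderivative of f (integration constants 0)
        a = P.polyval(1.0, F) - P.polyval(0.0, F)
        return -P.polyval(x, F) + a * x + P.polyval(0.0, F)

    xs = np.ravel(x)  # x from the example has shape (100, 1)
    for i in range(n):
        u_pred = np.ravel(y[i])  # prediction for the i-th sampled source term
        u_ref = exact_solution(features[i], xs)
        rel_err = np.linalg.norm(u_pred - u_ref) / np.linalg.norm(u_ref)
        print(f"source {i}: L2 relative error vs. exact solution = {rel_err:.3e}")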