Skip to content

Commit

Permalink
Reformat code, use L-BFGS for all backends.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Aug 2, 2023
1 parent b26515c commit c44269a
Showing 1 changed file with 10 additions and 21 deletions.
31 changes: 10 additions & 21 deletions examples/operator/pideeponet_poisson1d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
import numpy as np
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np


# Poisson equation: -u_xx = f
Expand All @@ -26,17 +26,11 @@ def boundary(_, on_boundary):
bc = dde.icbc.DirichletBC(geom, u_boundary, boundary)

# Define PDE
pde = dde.data.PDE(
geom,
equation,
bc,
num_domain=100,
num_boundary=2
)
pde = dde.data.PDE(geom, equation, bc, num_domain=100, num_boundary=2)

# Function space for f(x) are polynomials
degree = 3
space = dde.data.PowerSeries(N=degree+1)
space = dde.data.PowerSeries(N=degree + 1)

# Choose evaluation points
num_eval_points = 10
Expand All @@ -62,14 +56,9 @@ def boundary(_, on_boundary):

# Define and train model
model = dde.Model(pde_op, net)

if dde.backend.backend_name == "pytorch":
dde.optimizers.set_LBFGS_options(maxiter=1000)
model.compile("L-BFGS")
model.train()
else:
model.compile("adam", lr=0.001)
model.train(iterations=10000)
dde.optimizers.set_LBFGS_options(maxiter=1000)
model.compile("L-BFGS")
model.train()

# Plot realisations of f(x)
n = 3
Expand All @@ -85,18 +74,18 @@ def boundary(_, on_boundary):
plt.title("Poisson equation: Source term f(x) and solution u(x)")
plt.ylabel("f(x)")
z = np.zeros_like(x)
plt.plot(x, z, 'k-', alpha=0.1)
plt.plot(x, z, "k-", alpha=0.1)

# Plot source term f(x)
for i in range(n):
plt.plot(evaluation_points, fx[i], '--')
plt.plot(evaluation_points, fx[i], "--")

# Plot solution u(x)
plt.subplot(2, 1, 2)
plt.ylabel("u(x)")
plt.plot(x, z, 'k-', alpha=0.1)
plt.plot(x, z, "k-", alpha=0.1)
for i in range(n):
plt.plot(x, y[i], '-')
plt.plot(x, y[i], "-")
plt.xlabel("x")

plt.show()

0 comments on commit c44269a

Please sign in to comment.