Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/pde operator backend #1379

Merged
merged 6 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions deepxde/backend/tensorflow/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,30 @@ def to_numpy(input_tensor):
return input_tensor.numpy()


def concat(values, axis):
return tf.concat(values, axis)


def stack(values, axis):
return tf.stack(values, axis)


def expand_dims(tensor, axis):
return tf.expand_dims(tensor, axis)


def reverse(tensor, axis):
return tf.reverse(tensor, axis)


def roll(tensor, shift, axis):
return tf.roll(tensor, shift, axis)


def lgamma(x):
return tf.math.lgamma(x)


def elu(x):
return tf.nn.elu(x)

Expand Down Expand Up @@ -124,6 +148,10 @@ def tanh(x):
return tf.math.tanh(x)


def pow(x, y):
return tf.math.pow(x, y)


def mean(input_tensor, dim, keepdims=False):
return tf.math.reduce_mean(input_tensor, axis=dim, keepdims=keepdims)

Expand Down Expand Up @@ -178,3 +206,7 @@ def zeros(shape, dtype):

def zeros_like(input_tensor):
return tf.zeros_like(input_tensor)


def matmul(x, y):
return tf.linalg.matmul(x, y)
3 changes: 2 additions & 1 deletion deepxde/data/pde_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func):
losses.append(losses_i)

losses = zip(*losses)
losses = [bkd.reduce_mean(bkd.as_tensor(l)) for l in losses]
# Use stack instead of as_tensor to keep the gradients.
losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
return losses

def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
Expand Down
4 changes: 3 additions & 1 deletion deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def outputs(training, inputs):
return self.net(inputs)

def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
self.net.auxiliary_vars = auxiliary_vars
self.net.auxiliary_vars = None
if auxiliary_vars is not None:
samuelburbulla marked this conversation as resolved.
Show resolved Hide resolved
self.net.auxiliary_vars = torch.as_tensor(auxiliary_vars)
self.net.train(mode=training)
if isinstance(inputs, tuple):
inputs = tuple(
Expand Down
2 changes: 1 addition & 1 deletion examples/operator/advection_aligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, tensorflow"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion examples/operator/advection_aligned_pideeponet_2d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, tensorflow"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion examples/operator/antiderivative_aligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, tensorflow"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion examples/operator/diff_rec_aligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, tensorflow"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
Expand Down