Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorflow 1.x backend: layer-by-layer dropout rate setting for DeepONet #1792

Merged
merged 4 commits into from
Jul 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 80 additions & 14 deletions deepxde/nn/tensorflow_compat_v1/deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,13 @@ class DeepONet(NN):
both trunk and branch nets. If `activation` is a ``dict``, then the trunk
net uses the activation `activation["trunk"]`, and the branch net uses
`activation["branch"]`.
dropout_rate (float): The dropout rate, between 0 and 1.
dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, then the
same rate is used in both trunk and branch nets. If `dropout_rate`
is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
The list length should match the length of `layer_size_trunk` - 1 for the
trunk net and `layer_size_branch` - 2 for the branch net.
trainable_branch: Boolean.
trainable_trunk: Boolean or a list of booleans.
num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
Expand Down Expand Up @@ -219,7 +225,31 @@ def __init__(
"stacked " + kernel_initializer
)
self.regularizer = regularizers.get(regularization)
self.dropout_rate = dropout_rate
if isinstance(dropout_rate, dict):
self.dropout_rate_branch = dropout_rate["branch"]
self.dropout_rate_trunk = dropout_rate["trunk"]
else:
self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate
if isinstance(self.dropout_rate_branch, list):
if not (len(layer_sizes_branch) - 2) == len(self.dropout_rate_branch):
raise ValueError(
"Number of dropout rates of branch net must be "
f"equal to {len(layer_sizes_branch) - 2}"
)
else:
self.dropout_rate_branch = [self.dropout_rate_branch] * (
len(layer_sizes_branch) - 2
)
if isinstance(self.dropout_rate_trunk, list):
if not (len(layer_sizes_trunk) - 1) == len(self.dropout_rate_trunk):
raise ValueError(
"Number of dropout rates of trunk net must be "
f"equal to {len(layer_sizes_trunk) - 1}"
)
else:
self.dropout_rate_trunk = [self.dropout_rate_trunk] * (
len(layer_sizes_trunk) - 1
)
self.use_bias = use_bias
self.stacked = stacked
self.trainable_branch = trainable_branch
Expand Down Expand Up @@ -303,9 +333,11 @@ def build_branch_net(self):
activation=self.activation_branch,
trainable=self.trainable_branch,
)
if self.dropout_rate > 0:
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func, rate=self.dropout_rate, training=self.training
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
y_func = self._stacked_dense(
y_func,
Expand All @@ -324,9 +356,11 @@ def build_branch_net(self):
regularizer=self.regularizer,
trainable=self.trainable_branch,
)
if self.dropout_rate > 0:
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func, rate=self.dropout_rate, training=self.training
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
y_func = self._dense(
y_func,
Expand All @@ -351,9 +385,9 @@ def build_trunk_net(self):
if isinstance(self.trainable_trunk, (list, tuple))
else self.trainable_trunk,
)
if self.dropout_rate > 0:
if self.dropout_rate_trunk[i - 1] > 0:
y_loc = tf.layers.dropout(
y_loc, rate=self.dropout_rate, training=self.training
y_loc, rate=self.dropout_rate_trunk[i - 1], training=self.training
)
return y_loc

Expand Down Expand Up @@ -454,7 +488,13 @@ class DeepONetCartesianProd(NN):
both trunk and branch nets. If `activation` is a ``dict``, then the trunk
net uses the activation `activation["trunk"]`, and the branch net uses
`activation["branch"]`.
dropout_rate (float): The dropout rate, between 0 and 1.
dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, then the
same rate is used in both trunk and branch nets. If `dropout_rate`
is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
The list length should match the length of `layer_size_trunk` - 1 for the
trunk net and `layer_size_branch` - 2 for the branch net.
num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
`multi_output_strategy` below should be set.
multi_output_strategy (str or None): ``None``, "independent", "split_both", "split_branch" or
Expand Down Expand Up @@ -504,7 +544,31 @@ def __init__(
self.activation_branch = self.activation_trunk = activations.get(activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.regularizer = regularizers.get(regularization)
self.dropout_rate = dropout_rate
if isinstance(dropout_rate, dict):
self.dropout_rate_branch = dropout_rate["branch"]
self.dropout_rate_trunk = dropout_rate["trunk"]
else:
self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate
if isinstance(self.dropout_rate_branch, list):
if not (len(layer_size_branch) - 2) == len(self.dropout_rate_branch):
raise ValueError(
"Number of dropout rates of branch net must be "
f"equal to {len(layer_size_branch) - 2}"
)
else:
self.dropout_rate_branch = [self.dropout_rate_branch] * (
len(layer_size_branch) - 2
)
if isinstance(self.dropout_rate_trunk, list):
if not (len(layer_size_trunk) - 1) == len(self.dropout_rate_trunk):
raise ValueError(
"Number of dropout rates of trunk net must be "
f"equal to {len(layer_size_trunk) - 1}"
)
else:
self.dropout_rate_trunk = [self.dropout_rate_trunk] * (
len(layer_size_trunk) - 1
)
self._inputs = None

self.num_outputs = num_outputs
Expand Down Expand Up @@ -571,9 +635,11 @@ def build_branch_net(self):
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.regularizer,
)
if self.dropout_rate > 0:
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func, rate=self.dropout_rate, training=self.training
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
y_func = tf.layers.dense(
y_func,
Expand All @@ -596,9 +662,9 @@ def build_trunk_net(self):
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.regularizer,
)
if self.dropout_rate > 0:
if self.dropout_rate_trunk[i - 1] > 0:
y_loc = tf.layers.dropout(
y_loc, rate=self.dropout_rate, training=self.training
y_loc, rate=self.dropout_rate_trunk[i - 1], training=self.training
)
return y_loc

Expand Down
Loading