Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorflow 1.x backend: add dropout to DeepONet #1579

Merged
merged 3 commits into from
Jun 25, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions deepxde/nn/tensorflow_compat_v1/deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class DeepONet(NN):
both trunk and branch nets. If `activation` is a ``dict``, then the trunk
net uses the activation `activation["trunk"]`, and the branch net uses
`activation["branch"]`.
dropout_rate (float): The dropout rate, between 0 and 1.
trainable_branch: Boolean.
trainable_trunk: Boolean or a list of booleans.
num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
Expand Down Expand Up @@ -192,6 +193,7 @@ def __init__(
activation,
kernel_initializer,
regularization=None,
dropout_rate=0,
use_bias=True,
stacked=False,
trainable_branch=True,
Expand All @@ -217,6 +219,7 @@ def __init__(
"stacked " + kernel_initializer
)
self.regularizer = regularizers.get(regularization)
self.dropout_rate = dropout_rate
self.use_bias = use_bias
self.stacked = stacked
self.trainable_branch = trainable_branch
Expand Down Expand Up @@ -300,6 +303,10 @@ def build_branch_net(self):
activation=self.activation_branch,
trainable=self.trainable_branch,
)
if self.dropout_rate > 0:
y_func = tf.layers.dropout(
y_func, rate=self.dropout_rate, training=self.trainable_branch
lululxvi marked this conversation as resolved.
Show resolved Hide resolved
)
y_func = self._stacked_dense(
y_func,
1,
Expand All @@ -317,6 +324,10 @@ def build_branch_net(self):
regularizer=self.regularizer,
trainable=self.trainable_branch,
)
if self.dropout_rate > 0:
y_func = tf.layers.dropout(
y_func, rate=self.dropout_rate, training=self.trainable_branch
)
y_func = self._dense(
y_func,
self.layer_size_func[-1],
Expand All @@ -331,15 +342,22 @@ def build_trunk_net(self):
if self._input_transform is not None:
y_loc = self._input_transform(y_loc)
for i in range(1, len(self.layer_size_loc)):
trainable = (
self.trainable_trunk[i - 1]
if isinstance(self.trainable_trunk, (list, tuple))
else self.trainable_trunk
)
y_loc = self._dense(
y_loc,
self.layer_size_loc[i],
activation=self.activation_trunk,
regularizer=self.regularizer,
trainable=self.trainable_trunk[i - 1]
if isinstance(self.trainable_trunk, (list, tuple))
else self.trainable_trunk,
trainable=trainable,
lululxvi marked this conversation as resolved.
Show resolved Hide resolved
)
if self.dropout_rate > 0:
y_loc = tf.layers.dropout(
y_loc, rate=self.dropout_rate, training=trainable
)
return y_loc

def merge_branch_trunk(self, branch, trunk):
Expand Down Expand Up @@ -439,6 +457,7 @@ class DeepONetCartesianProd(NN):
both trunk and branch nets. If `activation` is a ``dict``, then the trunk
net uses the activation `activation["trunk"]`, and the branch net uses
`activation["branch"]`.
dropout_rate (float): The dropout rate, between 0 and 1.
num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
`multi_output_strategy` below should be set.
multi_output_strategy (str or None): ``None``, "independent", "split_both", "split_branch" or
Expand Down Expand Up @@ -474,6 +493,7 @@ def __init__(
activation,
kernel_initializer,
regularization=None,
dropout_rate=0,
num_outputs=1,
multi_output_strategy=None,
):
Expand All @@ -487,6 +507,7 @@ def __init__(
self.activation_branch = self.activation_trunk = activations.get(activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.regularizer = regularizers.get(regularization)
self.dropout_rate = dropout_rate
self._inputs = None

self.num_outputs = num_outputs
Expand Down Expand Up @@ -553,6 +574,8 @@ def build_branch_net(self):
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.regularizer,
)
if self.dropout_rate > 0:
y_func = tf.layers.dropout(y_func, rate=self.dropout_rate)
lululxvi marked this conversation as resolved.
Show resolved Hide resolved
y_func = tf.layers.dense(
y_func,
self.layer_size_func[-1],
Expand All @@ -574,6 +597,8 @@ def build_trunk_net(self):
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.regularizer,
)
if self.dropout_rate > 0:
y_loc = tf.layers.dropout(y_loc, rate=self.dropout_rate)
return y_loc

def merge_branch_trunk(self, branch, trunk):
Expand Down