Tensorflow 1.x backend: Unify argument names (#1887)
vl-dud authored Nov 22, 2024
1 parent 9ba2019 commit bb1d3ac
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions deepxde/nn/tensorflow_compat_v1/deeponet.py
@@ -160,8 +160,8 @@ class DeepONet(NN):
             is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
             and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
             and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
-            The list length should match the length of `layer_size_trunk` - 1 for the
-            trunk net and `layer_size_branch` - 2 for the branch net.
+            The list length should match the length of `layer_sizes_trunk` - 1 for the
+            trunk net and `layer_sizes_branch` - 2 for the branch net.
         trainable_branch: Boolean.
         trainable_trunk: Boolean or a list of booleans.
         num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
@@ -210,7 +210,7 @@ def __init__(
         super().__init__()
         if isinstance(trainable_trunk, (list, tuple)):
             if len(trainable_trunk) != len(layer_sizes_trunk) - 1:
-                raise ValueError("trainable_trunk does not match layer_size_trunk.")
+                raise ValueError("trainable_trunk does not match layer_sizes_trunk.")

         self.layer_size_func = layer_sizes_branch
         self.layer_size_loc = layer_sizes_trunk
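
For context, the check above requires a list passed as `trainable_trunk` to have `len(layer_sizes_trunk) - 1` entries, one per trunk weight layer. A minimal construction sketch assuming the tensorflow.compat.v1 backend; the layer widths, activation, initializer, and trainable flags below are illustrative assumptions, not taken from this commit:

import deepxde as dde

# Illustrative sizes: branch input of 100 sensors, trunk input of dimension 1.
layer_sizes_branch = [100, 40, 40]
layer_sizes_trunk = [1, 40, 40]

net = dde.nn.DeepONet(
    layer_sizes_branch,
    layer_sizes_trunk,
    "relu",
    "Glorot normal",
    # len(layer_sizes_trunk) - 1 = 2 entries, matching the check above;
    # a mismatched length raises the ValueError with the corrected message.
    trainable_trunk=[True, False],
)
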
@@ -490,11 +490,11 @@ class DeepONetCartesianProd(NN):
     """Deep operator network for dataset in the format of Cartesian product.

     Args:
-        layer_size_branch: A list of integers as the width of a fully connected network,
+        layer_sizes_branch: A list of integers as the width of a fully connected network,
             or `(dim, f)` where `dim` is the input dimension and `f` is a network
             function. The width of the last layer in the branch and trunk net
             should be the same for all strategies except "split_branch" and "split_trunk".
-        layer_size_trunk (list): A list of integers as the width of a fully connected
+        layer_sizes_trunk (list): A list of integers as the width of a fully connected
             network.
         activation: If `activation` is a ``string``, then the same activation is used in
             both trunk and branch nets. If `activation` is a ``dict``, then the trunk
@@ -505,8 +505,8 @@ class DeepONetCartesianProd(NN):
             is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
             and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
             and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
-            The list length should match the length of `layer_size_trunk` - 1 for the
-            trunk net and `layer_size_branch` - 2 for the branch net.
+            The list length should match the length of `layer_sizes_trunk` - 1 for the
+            trunk net and `layer_sizes_branch` - 2 for the branch net.
         num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
             `multi_output_strategy` below should be set.
         multi_output_strategy (str or None): ``None``, "independent", "split_both", "split_branch" or
@@ -537,8 +537,8 @@ class DeepONetCartesianProd(NN):

     def __init__(
         self,
-        layer_size_branch,
-        layer_size_trunk,
+        layer_sizes_branch,
+        layer_sizes_trunk,
         activation,
         kernel_initializer,
         regularization=None,
@@ -547,8 +547,8 @@ def __init__(
         multi_output_strategy=None,
     ):
         super().__init__()
-        self.layer_size_func = layer_size_branch
-        self.layer_size_loc = layer_size_trunk
+        self.layer_size_func = layer_sizes_branch
+        self.layer_size_loc = layer_sizes_trunk
         if isinstance(activation, dict):
             self.activation_branch = activations.get(activation["branch"])
             self.activation_trunk = activations.get(activation["trunk"])
@@ -562,24 +562,24 @@ def __init__(
         else:
             self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate
         if isinstance(self.dropout_rate_branch, list):
-            if not (len(layer_size_branch) - 2) == len(self.dropout_rate_branch):
+            if not (len(layer_sizes_branch) - 2) == len(self.dropout_rate_branch):
                 raise ValueError(
                     "Number of dropout rates of branch net must be "
-                    f"equal to {len(layer_size_branch) - 2}"
+                    f"equal to {len(layer_sizes_branch) - 2}"
                 )
         else:
             self.dropout_rate_branch = [self.dropout_rate_branch] * (
-                len(layer_size_branch) - 2
+                len(layer_sizes_branch) - 2
             )
         if isinstance(self.dropout_rate_trunk, list):
-            if not (len(layer_size_trunk) - 1) == len(self.dropout_rate_trunk):
+            if not (len(layer_sizes_trunk) - 1) == len(self.dropout_rate_trunk):
                 raise ValueError(
                     "Number of dropout rates of trunk net must be "
-                    f"equal to {len(layer_size_trunk) - 1}"
+                    f"equal to {len(layer_sizes_trunk) - 1}"
                 )
         else:
             self.dropout_rate_trunk = [self.dropout_rate_trunk] * (
-                len(layer_size_trunk) - 1
+                len(layer_sizes_trunk) - 1
             )
         self._inputs = None

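The rename also carries the documented dropout-rate rule through DeepONetCartesianProd: a branch list must have `len(layer_sizes_branch) - 2` entries and a trunk list `len(layer_sizes_trunk) - 1` entries. A usage sketch with illustrative values (layer widths, activation, initializer, and rates are assumptions, not part of the commit):

import deepxde as dde

layer_sizes_branch = [100, 128, 128, 128]  # len 4 -> 4 - 2 = 2 branch dropout rates
layer_sizes_trunk = [2, 128, 128, 128]     # len 4 -> 4 - 1 = 3 trunk dropout rates

net = dde.nn.DeepONetCartesianProd(
    layer_sizes_branch,
    layer_sizes_trunk,
    "relu",
    "Glorot normal",
    # List lengths match the rule in the docstring; wrong lengths raise ValueError.
    dropout_rate={"branch": [0.1, 0.1], "trunk": [0.1, 0.1, 0.1]},
)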