Commit
Removes the positivity constraint on the weights of the skip connections as it was unnecessary for maintaining convexity.
jkeeley-MW committed Sep 25, 2024
1 parent 8b931f6 commit 886e93a
Showing 85 changed files with 254 additions and 237 deletions.
27 changes: 11 additions & 16 deletions README.md
@@ -1,12 +1,13 @@
# AI Verification: Constrained Deep Learning [![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/constrained-deep-learning)
# AI Verification: Constrained Deep Learning

Constrained deep learning is an advanced approach to training deep neural networks by incorporating domain-specific constraints into the learning process. By integrating these constraints into the construction and training of neural networks, you can guarantee desirable behaviour in safety-critical scenarios where such guarantees are paramount.

This project aims to develop and evaluate deep learning models that adhere to predefined constraints, which could be in the form of physical laws, logical rules, or any other domain-specific knowledge. In the context of AI verification, constrained deep learning provides guarantees that certain desirable properties are present in the trained neural network by design. These desirable properties could include monotonicity, boundedness, and robustness, amongst others.

<figure>
<p align="center">
<img src="./documentation/figures/constrained_learning.svg">
<img src="./documentation/figures/constrained_learning.svg"
style="width:4in;height:1.1in">
</p>
</figure>

@@ -32,12 +33,12 @@ The repository contains several introductory, interactive examples as well as lo

### Introductory Examples (Short)
Below are links to markdown versions of MATLAB Live Scripts that you can view in GitHub&reg;.
- [Fully Input Convex Neural Networks in 1-Dimension](examples/convex/introductory/PoC_Ex1_1DFICNN.md)
- [Fully Input Convex Neural Networks in n-Dimensions](examples/convex/introductory/PoC_Ex2_nDFICNN.md)
- [Partially Input Convex Neural Networks in n-Dimensions](examples/convex/introductory/PoC_Ex3_nDPICNN.md)
- [Fully Input Monotonic Neural Networks in 1-Dimension](examples/monotonic/introductory/PoC_Ex1_1DFMNN.md)
- [Fully Input Monotonic Neural Networks in n-Dimensions](examples/monotonic/introductory/PoC_Ex2_nDFMNN.md)
- [Lipschitz Continuous Neural Networks in 1-Dimension](examples/lipschitz/introductory/PoC_Ex1_1DLNN.md)
- [Fully input convex neural networks in 1-dimension](examples/convex/introductory/PoC_Ex1_1DFICNN.md)
- [Fully input convex neural networks in n-dimensions](examples/convex/introductory/PoC_Ex2_nDFICNN.md)
- [Partially input convex neural networks in n-dimensions](examples/convex/introductory/PoC_Ex3_nDPICNN.md)
- [Fully input monotonic neural networks in 1-dimension](examples/monotonic/introductory/PoC_Ex1_1DFMNN.md)
- [Fully input monotonic neural networks in n-dimensions](examples/monotonic/introductory/PoC_Ex2_nDFMNN.md)
- [Lipschitz continuous neural networks in 1-dimension](examples/lipschitz/introductory/PoC_Ex1_1DLNN.md)

These examples make use of [custom training loops](https://uk.mathworks.com/help/deeplearning/deep-learning-custom-training-loops.html) and the [`arrayDatastore`](https://uk.mathworks.com/help/matlab/ref/matlab.io.datastore.arraydatastore.html) object. To learn more, click the links.
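
As a minimal sketch of how that data pipeline typically looks (the values and mini-batch settings here are illustrative, not taken from the examples):

```matlab
% Wrap training pairs in arrayDatastore objects, combine them, and queue
% formatted mini-batches for use inside a custom training loop.
x = linspace(-2, 2, 128)';   % 1-D inputs (illustrative)
y = x.^2;                    % a convex target to fit (illustrative)
ds  = combine(arrayDatastore(x), arrayDatastore(y));
mbq = minibatchqueue(ds, ...
    MiniBatchSize=32, ...
    MiniBatchFormat=["BC" "BC"]);   % batch ("B") by channel ("C") dlarrays
```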

@@ -70,13 +71,7 @@ As discussed in [1] (see 3.4.1.5), in certain situations, small violations in th

## Technical Articles

This repository focuses on the development and evaluation of deep learning models that adhere to constraints crucial for safety-critical applications, such as predictive maintenance for industrial machinery and equipment. Specifically, it focuses on enforcing monotonicity, convexity, and Lipschitz continuity within neural networks to ensure predictable and controlled behavior.

By emphasizing constraints like monotonicity, constrained neural networks ensure that predictions of the Remaining Useful Life (RUL) of components behave intuitively: as a machine's condition deteriorates, the estimated RUL should monotonically decrease. This is crucial in applications like aerospace or manufacturing, where an accurate and reliable estimation of RUL can prevent failures and save costs.

Alongside monotonicity, Lipschitz continuity is also enforced to guarantee model robustness and controlled behavior. This is essential in environments where safety and precision are paramount, such as control systems in autonomous vehicles or precision equipment in healthcare.

Convexity is especially beneficial for control systems as it inherently provides boundedness properties. For instance, by ensuring that the output of a neural network lies within a convex hull, it is possible to guarantee that the control commands remain within a safe and predefined operational space, preventing erratic or unsafe system behaviors. This boundedness property, derived from the convex nature of the model's output space, is critical for maintaining the integrity and safety of control systems under various conditions.
This repository focuses on the development and evaluation of deep learning models that adhere to constraints crucial for safety-critical applications, such as predictive maintenance for industrial machinery and equipment. Specifically, it focuses on enforcing monotonicity, convexity, and Lipschitz continuity within neural networks to ensure predictable and controlled behavior. By emphasizing constraints like monotonicity, constrained neural networks ensure that predictions of the Remaining Useful Life (RUL) of components behave intuitively: as a machine's condition deteriorates, the estimated RUL should monotonically decrease. This is crucial in applications like aerospace or manufacturing, where an accurate and reliable estimation of RUL can prevent failures and save costs. Alongside monotonicity, Lipschitz continuity is also enforced to guarantee model robustness and controlled behavior. This is essential in environments where safety and precision are paramount, such as control systems in autonomous vehicles or precision equipment in healthcare. Convexity is especially beneficial for control systems as it inherently provides boundedness properties. For instance, by ensuring that the output of a neural network lies within a convex hull, it is possible to guarantee that the control commands remain within a safe and predefined operational space, preventing erratic or unsafe system behaviors. This boundedness property, derived from the convex nature of the model's output space, is critical for maintaining the integrity and safety of control systems under various conditions.
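
To make the monotonicity point concrete, a decreasing-RUL constraint can be imposed at construction time. The sketch below is hedged: the "fully-monotonic" constraint name appears in `buildConstrainedNetwork` in this commit, while the `MonotonicTrend` option value `"decreasing"` and the layer sizes are assumptions.

```matlab
% Hedged sketch: a network constrained to be monotonic in its inputs,
% suitable for RUL-style outputs. MonotonicTrend="decreasing" is an
% assumed option value; sizes are illustrative.
numSensors = 8;   % illustrative number of condition-monitoring channels
rulNet = buildConstrainedNetwork("fully-monotonic", numSensors, [16 16 1], ...
    MonotonicTrend="decreasing");
```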

These technical articles explain key concepts of AI verification in the context of constrained deep learning. They include discussions on how to achieve the specified constraints in neural networks at construction and training time, as well as deriving and proving useful properties of constrained networks in AI verification applications. It is not necessary to go through these articles in order to explore this repository; however, you can find references and more in-depth discussion here.

@@ -90,4 +85,4 @@ These technical articles explain key concepts of AI verification in the context
- [3] Gouk, Henry, et al. “Regularisation of Neural Networks by Enforcing Lipschitz Continuity.” Machine Learning, vol. 110, no. 2, Feb. 2021, pp. 393–416. DOI.org (Crossref), https://doi.org/10.1007/s10994-020-05929-w
- [4] Kitouni, Ouail, et al. Expressive Monotonic Neural Networks. arXiv:2307.07512, arXiv, 14 July 2023. arXiv.org, http://arxiv.org/abs/2307.07512.

Copyright 2024, The MathWorks, Inc.
Copyright (c) 2024, The MathWorks, Inc.
14 changes: 7 additions & 7 deletions conslearn/+conslearn/+convex/buildFICNN.m
@@ -13,13 +13,13 @@
%
% BUILDFICNN name-value arguments:
%
% 'PositiveNonDecreasingActivation' - Specify the positive, convex,
% 'ConvexNonDecreasingActivation' - Specify the convex,
% non-decreasing activation functions.
% The options are 'softplus' or 'relu'.
% The default is 'softplus'.
%
% The construction of this network corresponds to Eq 2 in [1] with the
% exception that the application of the positive, non-decreasing activation
% exception that the application of the convex, non-decreasing activation
% function on the network output is not applied. This maintains convexity
% but permits positive and negative network outputs.
%
@@ -31,7 +31,7 @@
arguments
inputSize (1,:)
numHiddenUnits (1,:)
options.PositiveNonDecreasingActivation = 'softplus'
options.ConvexNonDecreasingActivation = 'softplus'
end

% Construct the correct input layer
@@ -43,7 +43,7 @@
end

% Loop over construction of hidden units
switch options.PositiveNonDecreasingActivation
switch options.ConvexNonDecreasingActivation
case 'relu'
pndFcn = @(k)reluLayer(Name="pnd_" + k);
case 'softplus'
@@ -68,10 +68,10 @@

% Add a cascading residual connection
for ii = 2:depth
tempLayers = fullyConnectedLayer(numHiddenUnits(ii),Name="fc_y_+_" + ii);
tempLayers = fullyConnectedLayer(numHiddenUnits(ii),Name="fc_y_" + ii);
lgraph = addLayers(lgraph,tempLayers);
lgraph = connectLayers(lgraph,"input","fc_y_+_" + ii);
lgraph = connectLayers(lgraph,"fc_y_+_" + ii,"add_" + ii + "/in2");
lgraph = connectLayers(lgraph,"input","fc_y_" + ii);
lgraph = connectLayers(lgraph,"fc_y_" + ii,"add_" + ii + "/in2");
end

% Initialize dlnetwork
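
After the rename in this file, a call to the fully input convex builder reads as in this sketch (sizes are illustrative):

```matlab
% Usage sketch for the renamed name-value argument (sizes illustrative).
% Returns a dlnetwork whose output is convex in its inputs.
net = conslearn.convex.buildFICNN(2, [16 16 1], ...
    ConvexNonDecreasingActivation="relu");   % or "softplus" (the default)
```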
8 changes: 4 additions & 4 deletions conslearn/+conslearn/+convex/buildPICNN.m
@@ -13,7 +13,7 @@
%
% BUILDPICNN name-value arguments:
%
% 'PositiveNonDecreasingActivation' - Specify the positive, convex,
% 'ConvexNonDecreasingActivation' - Specify the convex,
% non-decreasing activation functions.
% The options are 'softplus' or 'relu'.
% The default is 'softplus'.
@@ -32,7 +32,7 @@
% default value is 1.
%
% The construction of this network corresponds to Eq 3 in [1] with the
% exception that the application of the positive, non-decreasing activation
% exception that the application of the convex, non-decreasing activation
% function on the network output is not applied. This maintains convexity
% but permits positive and negative network outputs. Additionally, and in
% keeping with the notation used in the reference, in this implementation
@@ -50,7 +50,7 @@
arguments
inputSize (1,:) {iValidateInputSize(inputSize)}
numHiddenUnits (1,:)
options.PositiveNonDecreasingActivation = 'softplus'
options.ConvexNonDecreasingActivation = 'softplus'
options.Activation = 'tanh'
options.ConvexChannelIdx = 1
end
@@ -63,7 +63,7 @@
convexInputSize = numel(convexChannels);

% Prepare the two types of valid activation functions
switch options.PositiveNonDecreasingActivation
switch options.ConvexNonDecreasingActivation
case 'relu'
pndFcn = @(k)reluLayer(Name="pnd_" + k);
case 'softplus'
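
The partially convex builder accepts the same renamed argument alongside its own options; a sketch with illustrative sizes and channel choice:

```matlab
% Usage sketch (sizes and channel choice illustrative): the output is
% convex only in the input channel(s) selected by ConvexChannelIdx.
net = conslearn.convex.buildPICNN(2, [8 8 1], ...
    ConvexNonDecreasingActivation="softplus", ...   % the default
    Activation="tanh", ...                          % the default
    ConvexChannelIdx=2);
```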
24 changes: 12 additions & 12 deletions conslearn/buildConstrainedNetwork.m
@@ -19,7 +19,7 @@
%
% These options and default values apply to convex constrained networks:
%
% PositiveNonDecreasingActivation - Positive, convex, non-decreasing
% ConvexNonDecreasingActivation - Convex, non-decreasing
% ("fully-convex") activation functions.
% ("partially-convex") The options are "softplus" or "relu".
% The default is "softplus".
@@ -96,10 +96,10 @@
iValidateInputSize(inputSize)}
numHiddenUnits (1,:) {mustBeInteger,mustBeReal,mustBePositive}
% Convex
options.PositiveNonDecreasingActivation {...
options.ConvexNonDecreasingActivation {...
mustBeTextScalar, ...
mustBeMember(options.PositiveNonDecreasingActivation,["relu","softplus"]),...
iValidateConstraintWithPositiveNonDecreasingActivation(options.PositiveNonDecreasingActivation, constraint)}
mustBeMember(options.ConvexNonDecreasingActivation,["relu","softplus"]),...
iValidateConstraintWithConvexNonDecreasingActivation(options.ConvexNonDecreasingActivation, constraint)}
options.ConvexChannelIdx (1,:) {...
iValidateConstraintWithConvexChannelIdx(options.ConvexChannelIdx, inputSize, constraint), ...
mustBeNumeric,mustBePositive,mustBeInteger}
@@ -131,15 +131,15 @@
switch constraint
case "fully-convex"
% Set defaults
if ~any(fields(options) == "PositiveNonDecreasingActivation")
options.PositiveNonDecreasingActivation = "softplus";
if ~any(fields(options) == "ConvexNonDecreasingActivation")
options.ConvexNonDecreasingActivation = "softplus";
end
net = conslearn.convex.buildFICNN(inputSize, numHiddenUnits, ...
PositiveNonDecreasingActivation=options.PositiveNonDecreasingActivation);
ConvexNonDecreasingActivation=options.ConvexNonDecreasingActivation);
case "partially-convex"
% Set defaults
if ~any(fields(options) == "PositiveNonDecreasingActivation")
options.PositiveNonDecreasingActivation = "softplus";
if ~any(fields(options) == "ConvexNonDecreasingActivation")
options.ConvexNonDecreasingActivation = "softplus";
end
if ~any(fields(options) == "Activation")
options.Activation = "tanh";
@@ -148,7 +148,7 @@
options.ConvexChannelIdx = 1;
end
net = conslearn.convex.buildPICNN(inputSize, numHiddenUnits,...
PositiveNonDecreasingActivation=options.PositiveNonDecreasingActivation,...
ConvexNonDecreasingActivation=options.ConvexNonDecreasingActivation,...
Activation=options.Activation,...
ConvexChannelIdx=options.ConvexChannelIdx);
case "fully-monotonic"
@@ -259,9 +259,9 @@ function iValidateConstraintWithMonotonicTrend(param, constraint)
end
end

function iValidateConstraintWithPositiveNonDecreasingActivation(param, constraint)
function iValidateConstraintWithConvexNonDecreasingActivation(param, constraint)
if ( ~isequal(constraint, "fully-convex") && ~isequal(constraint,"partially-convex") ) && ~isempty(param)
error("'PositiveNonDecreasingActivation' is not an option for constraint " + constraint);
error("'ConvexNonDecreasingActivation' is not an option for constraint " + constraint);
end
end

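
At the user-facing level the rename surfaces as in this sketch (the constraint-first argument order follows the `switch constraint` shown above; sizes are illustrative):

```matlab
% Top-level usage sketch after the rename (sizes illustrative).
net = buildConstrainedNetwork("fully-convex", 1, [32 32 1], ...
    ConvexNonDecreasingActivation="softplus");
% With a non-convex constraint the option now errors with:
% "'ConvexNonDecreasingActivation' is not an option for constraint ..."
```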
10 changes: 10 additions & 0 deletions conslearn/trainConstrainedNetwork.m
@@ -167,6 +167,16 @@
end
end
end

% Update the training monitor status
if trainingOptions.TrainingMonitor
if monitor.Stop == 1
monitor.Status = "Training stopped";
else
monitor.Status = "Training complete";
end
end

end

%% Helpers
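
For context, a hypothetical call that exercises the new status update; only the `TrainingMonitor` field is confirmed by this diff, so the other option fields and the argument order are assumptions:

```matlab
% Hypothetical usage sketch (unconfirmed signature). Pressing Stop on the
% monitor now leaves status "Training stopped"; otherwise "Training complete".
opts.TrainingMonitor  = true;   % confirmed by this diff
opts.MaxEpochs        = 50;     % assumed option name
opts.InitialLearnRate = 1e-3;   % assumed option name
trainedNet = trainConstrainedNetwork("fully-convex", net, mbq, opts);
```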