O2O: ODE-based learning to optimize (O2O) is a comprehensive framework that builds on the inertial system with Hessian-driven damping (ISHD) to develop optimization methods. We focus on unconstrained convex smooth optimization, i.e., minimizing a convex, $L$-smooth function $f$ over $\mathbb{R}^n$.
The pipeline of O2O is built from the following key ingredients:
- Convergence guarantee: Ensures a stable discretization of the ODE by the forward Euler scheme, consisting of a convergence condition and a stability condition.
- Learning to optimize framework: Casts the search for the optimal ODE coefficients as a learning to optimize (L2O) problem and provides an algorithm to solve it numerically.
- Stopping time: A measure of the efficiency of the algorithm generated by discretizing the ODE, generalizing complexity from the discrete-time to the continuous-time setting (a possible formalization is sketched after this list).
- Probability distribution of a parameterized function family: Defines the probability distribution over the function family through the equivalence between each function and its parameters.
- Stochastic penalty algorithm (StoPM): Solves the L2O problem by combining the penalty function method with stochastic optimization algorithms. We provide convergence guarantees for StoPM under the sufficient decrease assumption, using only conservative gradients, which makes the algorithm more robust and general.
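As a rough illustration of the stopping time (our own hedged formalization, using the gradient norm as the accuracy criterion; consult the paper for the precise definition), for a trajectory $x(t)$ of the ODE starting at $t_0$ and a target accuracy $\epsilon>0$ one may take

$$T(\epsilon)\;=\;\inf\left\{\,t\ge t_{0}\;:\;\|\nabla f(x(t))\|\le\epsilon\,\right\},$$

so that a smaller stopping time indicates a more efficient continuous-time dynamic; the option --eps of train.py plays the role of $\epsilon$.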
├── classic_optimizer           # classic first-order methods for comparison
├── dataset                     # datasets used to generate the training and testing functions
├── problem                     # the logistic regression and lpp norm minimization problems
├── run.sh                      # scripts to run the different experiments
├── train.py                    # main script to train the model
├── vector_field                # vector fields of the different ODEs
└── visualization
    ├── calc_complexity.py        # calculate the averaged complexity of different algorithms on given optimization problems
    ├── calc_lastgrad.py          # calculate the averaged gradient norm at the last iteration of different algorithms on given optimization problems
    ├── compare_diff_epochs.py    # compare the performance of the learned algorithm at different epochs
    ├── generate_test_result.py   # use the trained coefficients to generate the test results
    ├── table_complexity.py       # organize the results of calc_complexity.py into a table (Tables 4 and 5 in the paper)
    ├── table_lastgrad.py         # organize the results of calc_lastgrad.py into a table (Tables 2 and 3 in the paper)
    ├── visualize_diff_epochs.py  # visualize the performance of the learned algorithm at different epochs (Fig. 4 in the paper)
    └── visualize_test_result.py  # visualize the test results (Fig. 5, 6, 7 in the paper)
The environment dependencies are exported in "requirements.yaml". The most convenient way to install them is with conda:
conda env create -f requirements.yaml
Run the following commands to train the model on the different optimization tasks.
python train.py --problem lpp --dataset a5a --num_epoch 80 --pen_coeff 0.5
python train.py --problem lpp --dataset separable --num_epoch 15 --pen_coeff 0.5 --eps 1e-4
python train.py --problem lpp --dataset covtype --num_epoch 50 --pen_coeff 0.5 --batch_size 10240
python train.py --problem logistic --dataset covtype --num_epoch 50 --pen_coeff 0.5 --batch_size 10240
python train.py --problem lpp --dataset w3a --num_epoch 100 --pen_coeff 0.5
After training finishes for a specific task, run the scripts in the visualization folder to generate the test results and visualize them. The trained models and checkpoints are saved in the folders "saved_models" and "checkpoints" under "train_log". You need to specify the paths of the model and checkpoints when running the visualization scripts.
usage: python train.py [-h] [--problem {logistic,lpp}] [--dataset {mushrooms,a5a,w3a,phishing,separable,covtype}] [--pretrain]
[--num_epoch NUM_EPOCH] [--pen_coeff PEN_COEFF] [--lr LR] [--momentum MOMENTUM] [--batch_size BATCH_SIZE]
[--seed SEED] [--init_it INIT_IT] [--discrete_stepsize DISCRETE_STEPSIZE] [--eps EPS] [--l2 L2] [--p P]
[--optim {SGD,Adam}] [--threshold THRESHOLD]
Train the neural ODE using the exact L1 penalty method.
optional arguments:
-h, --help show this help message and exit
--problem {logistic,lpp}
Either logistic regression (default: without L2 regularization) or Lpp minimization (default: p=4)
--dataset {mushrooms,a5a,w3a,phishing,separable,covtype}
Dataset used for training
--pretrain Load the pre-trained model or not
--num_epoch NUM_EPOCH
The number of training epochs
--pen_coeff PEN_COEFF
The penalty coefficient of the L1 exact penalty term
--lr LR Learning rate of SGD
--momentum MOMENTUM Momentum coefficient of SGD
--batch_size BATCH_SIZE
Batch size for training, default 1024, 10240 is recommended for covtype
--seed SEED Random seed for reproducing. 3407 is all you need
--init_it INIT_IT The number of iterations used to initialize the neural ODE, default is 300
--discrete_stepsize DISCRETE_STEPSIZE
the step size used in discretization, default is 0.04
--eps EPS epsilon used to define the stopping time
--l2 L2 the coefficient of the L2 regularization term in logistic regression
--p P the exponential index of lpp minimization
--optim {SGD,Adam} the optimizer used in training
--threshold THRESHOLD
the threshold used in the constraints
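For reference, an invocation combining several of the options above might look like the following; the values are illustrative rather than the tuned settings used in the paper.
python train.py --problem logistic --dataset mushrooms --num_epoch 80 --pen_coeff 0.5 --lr 0.01 --batch_size 1024 --optim Adam --seed 3407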
The datasets used in our experiments are summarized in Table 1. In this table, the "Separable" column indicates whether the dataset is linearly separable.
Dataset | Separable | References
---|---|---
a5a | No | [Dua and Graff, 2019]
w3a | No | [Platt, 1998]
mushrooms | Yes | [Dua and Graff, 2019]
covtype | No | [Dua and Graff, 2019]
phishing | No | [Dua and Graff, 2019]
separable | Yes | [Wilson et al., 2019]
Table 1: A summary of the datasets used in experiments.
All the datasets are designed for binary classification problems and are downloaded from the LIBSVM collection, except the separable dataset, which we construct by sampling data with the code snippet from [Wilson et al., 2019]. The label of each sample is binary. Predefined testing sets are provided only for a5a and w3a; for the datasets that do not specify a training/testing split, we divide them manually.
We compare the following baseline methods, where $L$ denotes the gradient Lipschitz constant of the objective.
- GD. The vanilla gradient descent GD is the standard method in optimization. We set the stepsize to $h=1/L$.

- NAG. Nesterov's accelerated gradient method NAG is a milestone of acceleration methods. We employ the version for convex functions
  $$y_{k+1}=x_{k}-h\nabla f(x_{k}),\quad x_{k+1}=y_{k+1}+\frac{k-1}{k+2}(y_{k+1}-y_{k}),$$
  where the stepsize is chosen as $h=1/L$.

- IGAHD. Inertial gradient algorithm with Hessian-driven damping. This method is obtained by applying a NAG-inspired time discretization to
  $$\ddot{x}(t)+\frac{\alpha}{t}\dot{x}(t)+\beta\nabla^2 f(x(t))\dot{x}(t)+\left(1+\frac{\beta}{t}\right)\nabla f(x(t))=0.$$
  Let $s=1/L$. In each iteration, setting $\alpha_{k}=1-\alpha/k$, the method performs
  $$y_k=x_k+\alpha_k\left(x_k-x_{k-1}\right)-\beta \sqrt{s}\left(\nabla f\left(x_k\right)-\nabla f\left(x_{k-1}\right)\right)-\frac{\beta \sqrt{s}}{k} \nabla f\left(x_{k-1}\right),\quad x_{k+1}=y_k-s \nabla f\left(y_k\right).$$
  In Attouch et al., 2020, it has been shown that IGAHD enjoys an $\mathcal{O}(1/k^2)$ convergence rate when $0\leq \beta< 2/\sqrt{s}$ and $s\leq 1/L$. Its performance may not coincide with NAG due to the gradient correction term. In our experiments, IGAHD serves as the baseline for optimization methods derived from the ODE viewpoint without learning.

- EIGAC. Explicit inertial gradient algorithm with correction. EIGAC performs the update
  $$\frac{x_{k+1}-x_{k}}{h}=v_{k}-\beta(t_{k})\nabla f(x_{k}),\quad \frac{v_{k+1}-v_{k}}{h}=-\frac{\alpha}{t_{k}}\left(v_{k}-\beta(t_{k})\nabla f(x_{k})\right)+(\dot{\beta}(t_{k})-\gamma(t_{k}))\nabla f(x_{k}).$$
  We provide two versions: one with the default coefficients $\alpha = 6$, $\beta(t) = \left({4}/{h} - {2\alpha}/{t}\right)/L$, and $\beta(t) = h\gamma(t)$, and one with the coefficients learned by O2O. The numerical experiments show that EIGAC with the default coefficients converges and performs comparably with NAG, while EIGAC with the learned coefficients is superior to the other methods. A minimal re-implementation of this iteration is sketched after this list.
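For illustration, here is a minimal NumPy sketch of the EIGAC iteration with the default coefficients above. This is our own re-implementation, not the package's code: the names (eigac, grad_f) and the starting time t0 are placeholders, and the derivative $\dot{\beta}(t)$ is evaluated analytically from $\beta(t)=\left(4/h-2\alpha/t\right)/L$.

```python
import numpy as np

def eigac(grad_f, x0, L, h=0.04, alpha=6.0, num_iter=500, t0=1.0):
    """Sketch of the EIGAC iteration with the default coefficients.

    grad_f: callable returning the gradient of the objective at x.
    L:      smoothness constant of the objective (assumed known).
    Illustrative only; see the package for the authoritative implementation.
    """
    beta = lambda t: (4.0 / h - 2.0 * alpha / t) / L   # default beta(t)
    dbeta = lambda t: (2.0 * alpha / t**2) / L         # analytic derivative of beta(t)
    gamma = lambda t: beta(t) / h                      # default relation beta(t) = h * gamma(t)

    x = x0.copy()
    v = np.zeros_like(x0)
    t = t0
    for _ in range(num_iter):
        g = grad_f(x)                                  # gradient at the current iterate x_k
        x = x + h * (v - beta(t) * g)                  # x-update of EIGAC
        v = v + h * (-(alpha / t) * (v - beta(t) * g)
                     + (dbeta(t) - gamma(t)) * g)      # v-update of EIGAC
        t += h
    return x
```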
We empirically compare these methods on two tasks: logistic regression and $\ell_p^p$ norm minimization. In our second task, given an even integer $p$ (the --p option, default $p=4$), we solve an $\ell_p^p$ minimization problem built from the same binary classification data.
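For concreteness, one common instantiation of these two objectives on a dataset $\{(a_i,b_i)\}_{i=1}^{m}$ with binary labels is sketched below; the exact formulations used in the experiments are defined in the problem folder, so the following should be read as an assumption rather than the definitive setup:

$$\min_{x}\ \frac{1}{m}\sum_{i=1}^{m}\log\bigl(1+\exp(-b_i\,a_i^{\top}x)\bigr)+\frac{\sigma}{2}\|x\|_2^2, \qquad \min_{x}\ \frac{1}{m}\sum_{i=1}^{m}\bigl(a_i^{\top}x-b_i\bigr)^{p},$$

where $\sigma\ge 0$ corresponds to the --l2 option and the even exponent $p$ to the --p option.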
We randomly generate 100 test functions for each problem by varying the instances drawn from the dataset. Each problem is specified by the dataset, the batch size, and the formulation (e.g., lpp_a5a denotes $\ell_p^p$ norm minimization on the a5a dataset).
The averaged performance measure is a metric that evaluates the effectiveness of a method in minimizing the objective functions in the test set. It is computed by calc_lastgrad.py and table_lastgrad.py from the gradient norm at the last iteration, averaged over the test functions and reported on a logarithmic scale.
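A minimal sketch of how such a measure can be computed is given below; the function name and the exact aggregation are our assumptions, not the package's implementation (see calc_lastgrad.py for the authoritative version).

```python
import numpy as np

def averaged_performance_measure(grad_norm_trajectories):
    """Average of log10 of the final gradient norm over the test functions.

    Hedged sketch, not the package's exact code.
    grad_norm_trajectories: list of 1-D arrays, one per test function,
    containing ||grad f(x_k)|| for k = 0, ..., K.
    """
    finals = [traj[-1] for traj in grad_norm_trajectories]
    return float(np.mean(np.log10(finals)))
```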
The averaged performance measure is reported in Tables 2 and 3. EIGAC with learned coefficients outperforms the other methods by at least an order of magnitude in most cases.
Method | mushrooms | a5a | w3a | phishing | covtype | separable |
---|---|---|---|---|---|---|
GD | -1.55 | -1.81 | -1.90 | -1.35 | -1.89 | -1.56 |
NAG | -3.37 | -3.11 | -3.26 | -3.01 | -3.07 | -3.66 |
EIGAC(initial) | -3.02 | -2.97 | -3.02 | -2.80 | -3.48 | -3.32 |
IGAHD | -3.02 | -2.97 | -3.02 | -2.80 | -3.48 | -3.31 |
EIGAC(learned) | -4.83 | -4.38 | -4.46 | -4.82 | -4.37 | -5.49 |
Table 2: Averaged performance measure in logistic regression problems.
Method | mushrooms | a5a | w3a | phishing | covtype | separable |
---|---|---|---|---|---|---|
GD | -2.49 | -2.79 | -3.18 | -2.36 | -2.65 | -2.95 |
NAG | -4.35 | -4.19 | -4.72 | -4.06 | -4.43 | -6.11 |
EIGAC(initial) | -4.16 | -3.99 | -4.66 | -4.37 | -4.47 | -6.15 |
IGAHD | -4.16 | -4.05 | -4.66 | -4.37 | -4.51 | -6.14 |
EIGAC(learned) | -5.27 | -5.11 | -5.71 | -5.65 | -5.14 | -7.55 |
Table 3: Averaged performance measure in $\ell_p^p$ norm minimization problems.
The averaged complexity is a metric that measures the computational efficiency of a method in reaching a desired level of accuracy. It is calculated by averaging, over the test functions, the number of iterations a method needs to reach the accuracy specified by eps, capped at the iteration budget of 500; see calc_complexity.py and table_complexity.py.
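A hedged sketch of this computation follows; the function name, the eps default, and the cap of 500 iterations are assumptions inferred from the tables and the command-line options rather than taken from the package's code (see calc_complexity.py for the authoritative version).

```python
import numpy as np

def averaged_complexity(grad_norm_trajectories, eps=1e-4, max_iter=500):
    """Average, over the test functions, of the first iteration at which the
    gradient norm drops below eps, capped at max_iter.

    Hedged sketch, not the package's exact code.
    """
    counts = []
    for traj in grad_norm_trajectories:
        hit = np.flatnonzero(np.asarray(traj) <= eps)  # iterations meeting the accuracy
        counts.append(int(hit[0]) if hit.size > 0 else max_iter)
    return float(np.mean(counts))
```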
The averaged complexity is presented in Tables 4 and 5. EIGAC with learned coefficients consistently improves the complexity, requiring only about half the iterations of the other methods on most problems.
Method | mushrooms | a5a | w3a | phishing | covtype | separable |
---|---|---|---|---|---|---|
GD | 500.00 | 500.00 | 500.00 | 500.00 | 500.00 | 500.00 |
NAG | 500.00 | 500.00 | 500.00 | 500.00 | 500.00 | 424.71 |
EIGAC(initial) | 500.00 | 500.00 | 500.00 | 500.00 | 497.12 | 500.00 |
IGAHD | 500.00 | 500.00 | 500.00 | 500.00 | 497.36 | 500.00 |
EIGAC(learned) | 153.48 | 227.32 | 216.42 | 182.15 | 237.88 | 11.49 |
Table 4: Averaged complexity in logistic regression problems.
Method | mushrooms | a5a | w3a | phishing | covtype | separable |
---|---|---|---|---|---|---|
GD | 500.00 | 500.00 | 500.00 | 500.00 | 500.00 | 500.00 |
NAG | 183.77 | 211.87 | 92.61 | 252.43 | 167.31 | 52.15 |
EIGAC(initial) | 235.07 | 245.68 | 96.17 | 224.36 | 203.02 | 22.10 |
IGAHD | 235.84 | 239.49 | 96.03 | 224.53 | 204.12 | 29.98 |
EIGAC(learned) | 93.12 | 122.16 | 50.93 | 85.57 | 109.15 | 11.00 |
Table 5: Averaged complexity in $\ell_p^p$ norm minimization problems.
We hope that the package is useful for your application. If you have any bug reports or comments, please feel free to email one of the toolbox authors:
- Zhonglin Xie, [email protected]
- Zaiwen Wen, [email protected]
Zhonglin Xie, Wotao Yin, Zaiwen Wen, ODE-based Learning to Optimize, arXiv:2406.02006, 2024.
GNU General Public License v3.0