WIP: Mm/uq app #337

Open
wants to merge 36 commits into
base: development
36 commits
1972ef5
Initial add UQ to cmake
Sep 27, 2023
7f2651f
Initial UQ draft using f = l2norm, g = l1norm
Sep 27, 2023
6ee9885
Minimise posterior arguments using lambda w/ capture
Sep 27, 2023
5b8da12
Add yaml parsing first instance
Sep 28, 2023
58bacb8
move wavelet op construction into function
Sep 28, 2023
385f263
Move measurement op & input data setup out of main
Sep 28, 2023
5b75d7c
Refactor out measurement operator creation
Oct 1, 2023
d8a63ee
Move wavelet operator call closer to other setup calls
Oct 1, 2023
3832000
Refactor out save functions
Oct 1, 2023
e6d013b
Add out_path to parser to prevent inconsistencies
Oct 1, 2023
4d160cb
Returning a struct for consistency
Oct 1, 2023
9087faa
Move refactored fns into new (poorly named) file
Oct 1, 2023
0920091
Add Purify config file read to UQ
Oct 1, 2023
efcefa8
Add wavelet operator to prior
Oct 2, 2023
da5ef51
Update readme to include UQ
Oct 2, 2023
09be40c
Merge branch 'development' into mm/uq_app
Nov 6, 2024
693681a
Add generic cost function
Nov 7, 2024
9f0923b
Remove some confusing root-2s
Nov 13, 2024
f1a1982
Move non-templated functions to cpp file!
Nov 14, 2024
48d3dd8
Spacing
Nov 14, 2024
24a812c
Merge branch 'development' of github.com:astro-informatics/purify int…
Nov 21, 2024
9570311
Use non-greek sopt interface
Nov 21, 2024
7ec30b4
Lintin'
Nov 21, 2024
c0fe3fb
Merge branch 'development' into mm/uq_app
Nov 21, 2024
a0c7b92
Merge branch 'mm/update_sopt_interfacing' into mm/uq_app
Nov 21, 2024
1d35084
Unified approach for f/g function types
Nov 21, 2024
734abcf
Remove gProximalType references
Nov 21, 2024
773bd91
Move type map into types header
Nov 21, 2024
632ecb9
Add differentiable function to yaml
Nov 21, 2024
2dd9c4c
Optional yaml parsing
Nov 21, 2024
0c11af2
syntax errors
Nov 21, 2024
277b8ca
Add construction of f and g
Nov 21, 2024
6af78c7
Fix sigma shadowing
Nov 21, 2024
7bfd22b
Refactor function setup for re-use
Nov 21, 2024
f089c07
Default to avoid null pointers
Nov 21, 2024
65cdc87
add onnxrt to main
Nov 21, 2024
28 changes: 28 additions & 0 deletions README.md
@@ -91,6 +91,34 @@ When `purify` runs a directory will be created, and the output images will be
saved and time-stamped. Additionally, a config file with the settings used will
be saved and time-stamped, helping for reproducibility and book-keeping.

Uncertainty Quantification
--------------------------

Bayesian hypothesis testing may be performed with the `purify_UQ` application, which is located in the `build` directory.

This application takes a YAML config file with the following parameters:
- `confidence_interval` or `alpha`, where `alpha = 1 - confidence_interval`.
- `measurements_path`: path to measurements data (.vis file)
- `reference_image_path`: path to the reference image, i.e. the output from a purify run (.fits file)
- `surrogate_image_path`: path to the surrogate image, i.e. a doctored image with the blurring or structural change that you want to test (.fits file)
- `sigma`: standard deviation for Gaussian likelihood.
- `gamma`: multiplicative factor for prior.
- `purify_config`: path to purify config used to generate the reference image. **This should be used if available in order to ensure consistency of things like measurement and wavelet operators.**

You can then run the uncertainty quantification with the command:
```
purify_UQ <path to UQ_config yaml>
```

The application will report the value of the objective function for each image, the threshold value calculated from the reference image, and whether the surrogate image is ruled out or not.
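
The README does not say how the threshold is computed from the reference image. As a minimal sketch, assuming the approximate highest-posterior-density bound often used for log-concave posteriors (commonly attributed to Pereyra, 2017), the threshold is the reference objective plus $N(\tau_\alpha + 1)$ with $\tau_\alpha = \sqrt{16 \log(3/\alpha) / N}$ for an image of $N$ pixels. Whether `purify_UQ` uses exactly this rule is an assumption, and the function names below are hypothetical, not part of purify:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical helper: alpha = 1 - confidence interval, as stated in the parameter list.
double alpha_from_confidence(double confidence_interval) {
  assert(confidence_interval > 0.0 && confidence_interval < 1.0);
  return 1.0 - confidence_interval;
}

// Assumed threshold rule (not confirmed by this PR):
//   threshold = objective(reference) + N * (tau_alpha + 1),
//   tau_alpha = sqrt(16 * log(3 / alpha) / N)
double hpd_threshold(double reference_objective, std::size_t n_pixels, double alpha) {
  assert(n_pixels > 0 && alpha > 0.0 && alpha < 1.0);
  const double n = static_cast<double>(n_pixels);
  const double tau_alpha = std::sqrt(16.0 * std::log(3.0 / alpha) / n);
  return reference_objective + n * (tau_alpha + 1.0);
}

// The surrogate image is ruled out when its objective exceeds the threshold.
bool surrogate_ruled_out(double surrogate_objective, double threshold) {
  return surrogate_objective > threshold;
}
```

A surrogate whose objective stays at or below the threshold cannot be rejected at the chosen confidence level, which is the "ruled out or not" decision the application reports.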

Presently this is designed to work for the unconstrained problem where:
- The negative log-likelihood is a scaled squared L2-norm, i.e. $\frac{1}{2 \sigma^2} \sum_i |y_i - (\Phi x)_i|^2$ for some data $y$, image $x$, and measurement operator $\Phi$. (Equivalent to an independent multivariate Gaussian likelihood.)
- The negative log prior is a scaled L1-norm in _wavelet space_, i.e. $\gamma \sum_i |(\Psi^\dag x)_i|$ for some image $x$ and wavelet operator $\Psi$.
- The objective function is the sum of these two terms.
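
The objective described in this list can be sketched in plain C++. This is an illustration only: it assumes the L2 term is the squared residual norm and the L1 term sums absolute values, and it takes the already-computed model visibilities $\Phi x$ and wavelet coefficients $\Psi^\dag x$ as inputs rather than applying purify's operators. All function and parameter names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using Complex = std::complex<double>;

// Negative log-likelihood: (1 / (2 sigma^2)) * sum_i |y_i - (Phi x)_i|^2.
// `model_vis` is assumed to already hold Phi x.
double neg_log_likelihood(const std::vector<Complex>& y,
                          const std::vector<Complex>& model_vis, double sigma) {
  assert(sigma > 0.0 && y.size() == model_vis.size());
  double sum_sq = 0.0;
  for (std::size_t i = 0; i < y.size(); ++i) {
    sum_sq += std::norm(y[i] - model_vis[i]);  // std::norm is the squared magnitude
  }
  return sum_sq / (2.0 * sigma * sigma);
}

// Negative log-prior: gamma * sum_i |(Psi^dagger x)_i|.
// `wavelet_coeffs` is assumed to already hold Psi^dagger x.
double neg_log_prior(const std::vector<Complex>& wavelet_coeffs, double gamma) {
  double sum_abs = 0.0;
  for (const Complex& c : wavelet_coeffs) sum_abs += std::abs(c);
  return gamma * sum_abs;
}

// Objective: the sum of the two terms above.
double objective(const std::vector<Complex>& y, const std::vector<Complex>& model_vis,
                 const std::vector<Complex>& wavelet_coeffs, double sigma, double gamma) {
  return neg_log_likelihood(y, model_vis, sigma) + neg_log_prior(wavelet_coeffs, gamma);
}
```

For example, with data `{1, 2i}`, model visibilities `{0, 0}`, `sigma = 1`, wavelet coefficients `{3 + 4i, -1}`, and `gamma = 0.5`, the likelihood term is 2.5 and the prior term is 3.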

Docker
-------
## Debugging the CI workflow with tmate

The CI workflow has a manual dispatch trigger which allows you to log into the job while it's running. You can trigger it in
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
@@ -45,6 +45,8 @@ if (docs)
add_subdirectory(docs)
endif()

add_subdirectory(uncertainty_quantification)

add_executable(purify main.cc)
set_target_properties(purify PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR})
target_link_libraries(purify libpurify ${sopt_LIBRARIES})
429 changes: 67 additions & 362 deletions cpp/main.cc

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cpp/purify/CMakeLists.txt
@@ -35,7 +35,7 @@ set(HEADERS
set(SOURCES utilities.cc pfitsio.cc logging.cc
kernels.cc wproj_utilities.cc operators.cc uvfits.cc yaml-parser.cc
read_measurements.cc distribute.cc integration.cc wide_field_utilities.cc wkernel_integration.cc
- wproj_operators.cc uvw_utilities.cc)
+ wproj_operators.cc uvw_utilities.cc setup_utils.cc)

if(PURIFY_CASACORE)
list(APPEND SOURCES casacore.cc)
49 changes: 24 additions & 25 deletions cpp/purify/algorithm_factory.h
@@ -35,14 +35,11 @@ namespace purify {
namespace factory {
enum class algorithm { padmm, primal_dual, sdmm, forward_backward };
enum class algo_distribution { serial, mpi_serial, mpi_distributed, mpi_random_updates };
- enum class g_proximal_type { L1GProximal, TFGProximal, Indicator };
const std::map<std::string, algo_distribution> algo_distribution_string = {
{"none", algo_distribution::serial},
{"serial-equivalent", algo_distribution::mpi_serial},
{"random-updates", algo_distribution::mpi_random_updates},
{"fully-distributed", algo_distribution::mpi_distributed}};
- const std::map<std::string, g_proximal_type> g_proximal_type_string = {
- {"l1", g_proximal_type::L1GProximal}, {"learned", g_proximal_type::TFGProximal}};

//! return chosen algorithm given parameters
template <class Algorithm, class... ARGS>
@@ -80,7 +77,7 @@ padmm_factory(const algo_distribution dist,
.l1_proximal_positivity_constraint(positive_constraint)
.l1_proximal_real_constraint(real_constraint)
.lagrange_update_scale(0.9)
- .nu(op_norm * op_norm)
+ .sq_op_norm(op_norm * op_norm)
.Psi(*wavelets)
.Phi(*measurements);
#ifdef PURIFY_MPI
@@ -90,10 +87,11 @@ padmm_factory(const algo_distribution dist,
switch (dist) {
case (algo_distribution::serial):
padmm
- ->gamma((wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval())
- .cwiseAbs()
- .maxCoeff() *
- 1e-3)
+ ->regulariser_strength(
+ (wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval())
+ .cwiseAbs()
+ .maxCoeff() *
+ 1e-3)
.l2ball_proximal_epsilon(epsilon)
.residual_tolerance(epsilon * residual_tolerance_scaling);
return padmm;
@@ -132,8 +130,8 @@ padmm_factory(const algo_distribution dist,
std::weak_ptr<Algorithm> const padmm_weak(padmm);
// set epsilon
padmm->residual_tolerance(epsilon * residual_tolerance_scaling).l2ball_proximal_epsilon(epsilon);
- // set gamma
- padmm->gamma(comm.all_reduce(
+ // set regulariser_strength
+ padmm->regulariser_strength(comm.all_reduce(
utilities::step_size<Vector<t_complex>>(uv_data.vis, measurements, wavelets, sara_size) *
1e-3,
MPI_MAX));
@@ -164,7 +162,7 @@ fb_factory(const algo_distribution dist,
const bool tight_frame = false, const t_real relative_variation = 1e-3,
const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50,
const t_real op_norm = 1, const std::string model_path = "",
- const g_proximal_type g_proximal = g_proximal_type::L1GProximal,
+ const nondiff_func_type g_proximal = nondiff_func_type::L1Norm,
std::shared_ptr<DifferentiableFunc<typename Algorithm::Scalar>> f_function = nullptr) {
typedef typename Algorithm::Scalar t_scalar;
if (sara_size > 1 and tight_frame)
@@ -173,19 +171,19 @@ fb_factory(const algo_distribution dist,
"one wavelet basis.");
auto fb = std::make_shared<Algorithm>(uv_data.vis);
fb->itermax(max_iterations)
- .gamma(reg_parameter)
- .sigma(sigma * std::sqrt(2))
- .beta(step_size * std::sqrt(2))
+ .regulariser_strength(reg_parameter)
+ .sigma(sigma)
+ .step_size(step_size)
.relative_variation(relative_variation)
.tight_frame(tight_frame)
- .nu(op_norm * op_norm)
+ .sq_op_norm(op_norm * op_norm)
.Phi(*measurements);

if (f_function) fb->f_function(f_function); // only override f_function default if non-null
std::shared_ptr<NonDifferentiableFunc<t_scalar>> g;

switch (g_proximal) {
- case (g_proximal_type::L1GProximal): {
+ case (nondiff_func_type::L1Norm): {
// Create a shared pointer to an instance of the L1GProximal class
// and set its properties
auto l1_gp = std::make_shared<sopt::algorithm::L1GProximal<t_scalar>>(false);
@@ -205,7 +203,7 @@ fb_factory(const algo_distribution dist,
g = l1_gp;
break;
}
- case (g_proximal_type::TFGProximal): {
+ case (nondiff_func_type::Denoiser): {
#ifdef PURIFY_ONNXRT
// Create a shared pointer to an instance of the TFGProximal class
g = std::make_shared<sopt::algorithm::TFGProximal<t_scalar>>(model_path);
@@ -215,8 +213,8 @@ fb_factory(const algo_distribution dist,
"Type TFGProximal not recognized because purify was built with onnxrt=off");
#endif
}
- case (g_proximal_type::Indicator): {
- g = std::make_shared<RealIndicator<t_scalar>>();
+ case (nondiff_func_type::RealIndicator): {
+ g = std::make_shared<sopt::algorithm::RealIndicator<t_scalar>>();
break;
}
default: {
@@ -282,10 +280,11 @@ primaldual_factory(
switch (dist) {
case (algo_distribution::serial): {
primaldual
- ->gamma((wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval())
- .cwiseAbs()
- .maxCoeff() *
- 1e-3)
+ ->regulariser_strength(
+ (wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval())
+ .cwiseAbs()
+ .maxCoeff() *
+ 1e-3)
.l2ball_proximal_epsilon(epsilon)
.residual_tolerance(epsilon * residual_tolerance_scaling);
return primaldual;
@@ -345,8 +344,8 @@ primaldual_factory(
// set epsilon
primaldual->residual_tolerance(epsilon * residual_tolerance_scaling)
.l2ball_proximal_epsilon(epsilon);
- // set gamma
- primaldual->gamma(comm.all_reduce(
+ // set regulariser_strength
+ primaldual->regulariser_strength(comm.all_reduce(
utilities::step_size<Vector<t_complex>>(uv_data.vis, measurements, wavelets, sara_size) *
1e-3,
MPI_MAX));
26 changes: 26 additions & 0 deletions cpp/purify/pfitsio.cc
@@ -91,4 +91,30 @@ void write3d(const std::vector<Image<t_real>> &eigen_images, const std::string &
write3d(eigen_images, header, overwrite);
}

//! Read cube from fits file
std::vector<Image<t_complex>> read3d(const std::string &fits_name) {
std::vector<Image<t_complex>> eigen_images;
Vector<double> image;
int rows, cols, channels, pols = 1;
read3d<Vector<double>>(fits_name, image, rows, cols, channels, pols);
for (int i = 0; i < channels; i++) {
Vector<t_complex> eigen_image = Vector<t_complex>::Zero(rows * cols);
eigen_image.real() = image.segment(i * rows * cols, rows * cols);
eigen_images.push_back(Image<t_complex>::Map(eigen_image.data(), rows, cols));
}
return eigen_images;
}

//! Read image from fits file
Image<t_complex> read2d(const std::string &fits_name) {
/*
Reads in an image from a fits file and returns the image.

fits_name:: name of fits file
*/

const std::vector<Image<t_complex>> images = read3d(fits_name);
return images.at(0);
}

} // namespace purify::pfitsio
24 changes: 2 additions & 22 deletions cpp/purify/pfitsio.h
@@ -322,30 +322,10 @@ void read3d(const std::string &fits_name, Eigen::EigenBase<T> &output, int &rows
}

//! Read cube from fits file
- std::vector<Image<t_complex>> read3d(const std::string &fits_name) {
- std::vector<Image<t_complex>> eigen_images;
- Vector<double> image;
- int rows, cols, channels, pols = 1;
- read3d<Vector<double>>(fits_name, image, rows, cols, channels, pols);
- for (int i = 0; i < channels; i++) {
- Vector<t_complex> eigen_image = Vector<t_complex>::Zero(rows * cols);
- eigen_image.real() = image.segment(i * rows * cols, rows * cols);
- eigen_images.push_back(Image<t_complex>::Map(eigen_image.data(), rows, cols));
- }
- return eigen_images;
- }
+ std::vector<Image<t_complex>> read3d(const std::string &fits_name);

//! Read image from fits file
- Image<t_complex> read2d(const std::string &fits_name) {
- /*
- Reads in an image from a fits file and returns the image.
-
- fits_name:: name of fits file
- */
-
- const std::vector<Image<t_complex>> images = read3d(fits_name);
- return images.at(0);
- }
+ Image<t_complex> read2d(const std::string &fits_name);

} // namespace purify::pfitsio
