From 1972ef5248c38f442c725a6204765d7fbb60afdc Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Wed, 27 Sep 2023 20:36:03 +0100 Subject: [PATCH 01/32] Initial add UQ to cmake --- cpp/CMakeLists.txt | 2 ++ cpp/uncertainty_quantification/CMakeLists.txt | 9 +++++++++ cpp/uncertainty_quantification/uq_main.cc | 10 ++++++++++ 3 files changed, 21 insertions(+) create mode 100644 cpp/uncertainty_quantification/CMakeLists.txt create mode 100644 cpp/uncertainty_quantification/uq_main.cc diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4553b6f96..cb00c384b 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -45,6 +45,8 @@ if (docs) add_subdirectory(docs) endif() +add_subdirectory(uncertainty_quantification) + add_executable(purify main.cc) target_link_libraries(purify libpurify) set_target_properties(purify PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) diff --git a/cpp/uncertainty_quantification/CMakeLists.txt b/cpp/uncertainty_quantification/CMakeLists.txt new file mode 100644 index 000000000..5c2e7918d --- /dev/null +++ b/cpp/uncertainty_quantification/CMakeLists.txt @@ -0,0 +1,9 @@ +add_executable(purify_UQ uq_main.cc) +target_link_libraries(purify_UQ libpurify) +set_target_properties(purify_UQ PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) + +install(TARGETS purify_UQ + EXPORT PurifyTargets + DESTINATION share/cmake/Purify + RUNTIME DESTINATION bin + ) \ No newline at end of file diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc new file mode 100644 index 000000000..78846fc6a --- /dev/null +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -0,0 +1,10 @@ +#include + +int main(int argc, char **argv) +{ + std::cout << "Uncertainty Quantification." << std::endl; + + + + return 0; +} \ No newline at end of file From 7f2651f0e2516f8925c066b224e553b311da17fe Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Wed, 27 Sep 2023 22:27:27 +0100 Subject: [PATCH 02/32] Initial UQ draft using f = l2norm, g = l1norm --- cpp/uncertainty_quantification/uq_main.cc | 110 +++++++++++++++++++++- 1 file changed, 109 insertions(+), 1 deletion(-) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 78846fc6a..84b21fd8b 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -1,10 +1,118 @@ #include +#include "purify/pfitsio.h" +#include "purify/utilities.h" +#include "purify/measurement_operator_factory.h" +#include "sopt/objective_functions.h" +#include + +using VectorC = sopt::Vector>; + +double Posterior(const VectorC &image, + const VectorC &measurements, + const sopt::LinearTransform &measurement_operator, + const double sigma, + const double gamma) +{ + const auto residuals = (measurement_operator * image) - measurements; + return residuals.squaredNorm() / (2 * sigma * sigma) + image.cwiseAbs().sum() * gamma; +} int main(int argc, char **argv) { - std::cout << "Uncertainty Quantification." << std::endl; + if(argc != 7) + { + std::cout << "Please provide the following six arguments: " << std::endl; + std::cout << "Path for measurement data." << std::endl; + std::cout << "Path for reference image (.fits file)." << std::endl; + std::cout << "Path for surrogate iamge (.fits file)." << std::endl; + std::cout << "Confidence interval." << std::endl; + std::cout << "sigma (Gaussian Likelihood parameter)." << std::endl; + std::cout << "gamma (scaling of L1-norm prior)." << std::endl; + return 1; + } + + const std::string measurements_path = argv[1]; + const std::string ref_image_path = argv[2]; + const std::string surrogate_image_path = argv[3]; + const double confidence = strtod(argv[4], nullptr); + const double alpha = 1 - confidence; + const double sigma = strtod(argv[5], nullptr); + const double gamma = strtod(argv[6], nullptr); + + const purify::utilities::vis_params measurement_data = purify::utilities::read_visibility(measurements_path, false); + + const auto reference_image = purify::pfitsio::read2d(ref_image_path); + const VectorC reference_vector = VectorC::Map(reference_image.data(), reference_image.size()); + const auto surrogate_image = purify::pfitsio::read2d(surrogate_image_path); + const VectorC surrogate_vector = VectorC::Map(surrogate_image.data(), surrogate_image.size()); + + const uint imsize_x = reference_image.cols(); + const uint imsize_y = reference_image.rows(); + if((imsize_x != surrogate_image.cols()) || (imsize_y != surrogate_image.rows())) + { + std::cout << "Surrogate and reference images have different dimensions. Aborting." << std::endl; + return 2; + } + + // This is the measurement operator used in the test but this should probably be selectable + auto const measurement_operator = + purify::factory::measurement_operator_factory>>( + purify::factory::distributed_measurement_operator::serial, + measurement_data, + imsize_y, + imsize_x, + 1, + 1, + 2, + purify::kernels::kernel_from_string.at("kb"), + 4, + 4 + ); + + if( ((*measurement_operator) * reference_vector).size() != measurement_data.vis.size()) + { + std::cout << "Image size is not compatible with the measurement operator and data provided." << std::endl; + return 3; + } + + // Calculate the posterior function for the reference image + // posterior = likelihood + prior + // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm) + // Prior = Sum(|x_i|) * gamma (L1 norm) + //auto Posterior = [&measurement_data, measurement_operator, sigma, + // gamma](const VectorC &image) { + // { + // + // } + //}; + + const double reference_posterior = Posterior(reference_vector, + measurement_data.vis, + *measurement_operator, + sigma, + gamma); + const double surrogate_posterior = Posterior(surrogate_vector, + measurement_data.vis, + *measurement_operator, + sigma, + gamma); + + // Threshold for surrogate image posterior to be within confidence limit + const double N = imsize_x * imsize_y; + const double tau = std::sqrt( 16 * std::log(3 / alpha) ); + const double threshold = reference_posterior + tau * std::sqrt(N) + N; + + std::cout << "Uncertainty Quantification." << std::endl; + std::cout << "Reference Log Posterior = " << reference_posterior << std::endl; + std::cout << "Confidence interval = " << confidence << std::endl; + std::cout << "Log Posterior threshold = " << threshold << std::endl; + std::cout << "Surrogate Log Posterior = " << surrogate_posterior << std::endl; + std::cout << "Surrogate image is " + << ((surrogate_posterior <= threshold) ? "within the credible interval." + : "excluded by the credible interval.") + << std::endl; return 0; } \ No newline at end of file From 6ee9885249f7495567e9c76259237c411bc37e3f Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Wed, 27 Sep 2023 22:31:38 +0100 Subject: [PATCH 03/32] Minimise posterior arguments using lambda w/ capture --- cpp/uncertainty_quantification/uq_main.cc | 34 ++++++----------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 84b21fd8b..2eee8b1e9 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -7,16 +7,6 @@ using VectorC = sopt::Vector>; -double Posterior(const VectorC &image, - const VectorC &measurements, - const sopt::LinearTransform &measurement_operator, - const double sigma, - const double gamma) -{ - const auto residuals = (measurement_operator * image) - measurements; - return residuals.squaredNorm() / (2 * sigma * sigma) + image.cwiseAbs().sum() * gamma; -} - int main(int argc, char **argv) { if(argc != 7) @@ -81,23 +71,15 @@ int main(int argc, char **argv) // posterior = likelihood + prior // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm) // Prior = Sum(|x_i|) * gamma (L1 norm) - //auto Posterior = [&measurement_data, measurement_operator, sigma, - // gamma](const VectorC &image) { - // { - // - // } - //}; + auto Posterior = [&measurement_data, measurement_operator, sigma, gamma](const VectorC &image) { + { + const auto residuals = (*measurement_operator * image) - measurement_data.vis; + return residuals.squaredNorm() / (2 * sigma * sigma) + image.cwiseAbs().sum() * gamma; + } + }; - const double reference_posterior = Posterior(reference_vector, - measurement_data.vis, - *measurement_operator, - sigma, - gamma); - const double surrogate_posterior = Posterior(surrogate_vector, - measurement_data.vis, - *measurement_operator, - sigma, - gamma); + const double reference_posterior = Posterior(reference_vector); + const double surrogate_posterior = Posterior(surrogate_vector); // Threshold for surrogate image posterior to be within confidence limit const double N = imsize_x * imsize_y; From 5b8da1218b3315c36455a59ba829813b95b7a9f1 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 28 Sep 2023 13:13:56 +0100 Subject: [PATCH 04/32] Add yaml parsing first instance --- cpp/uncertainty_quantification/uq_main.cc | 69 +++++++++++++++-------- 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 2eee8b1e9..67665f0f9 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -4,31 +4,57 @@ #include "purify/measurement_operator_factory.h" #include "sopt/objective_functions.h" #include +#include "yaml-cpp/yaml.h" +#include "purify/yaml-parser.h" using VectorC = sopt::Vector>; int main(int argc, char **argv) { - if(argc != 7) + if(argc != 2) { - std::cout << "Please provide the following six arguments: " << std::endl; - std::cout << "Path for measurement data." << std::endl; - std::cout << "Path for reference image (.fits file)." << std::endl; - std::cout << "Path for surrogate iamge (.fits file)." << std::endl; - std::cout << "Confidence interval." << std::endl; - std::cout << "sigma (Gaussian Likelihood parameter)." << std::endl; - std::cout << "gamma (scaling of L1-norm prior)." << std::endl; + std::cout << "purify_UQ should be run using a single additional argument, which is the path to the config (yaml) file." << std::endl; + std::cout << "purify_UQ " << std::endl; + std::cout << std::endl; + std::cout << "For more information about the contents of the config file please consult the README." << std::endl; return 1; } - const std::string measurements_path = argv[1]; - const std::string ref_image_path = argv[2]; - const std::string surrogate_image_path = argv[3]; - const double confidence = strtod(argv[4], nullptr); - const double alpha = 1 - confidence; - const double sigma = strtod(argv[5], nullptr); - const double gamma = strtod(argv[6], nullptr); - + // Load and parse the config for parameters + const std::string config_path = argv[1]; + const YAML::Node UQ_config = YAML::LoadFile(config_path); + + const std::string measurements_path = UQ_config["measurements_path"].as(); + const std::string ref_image_path = UQ_config["reference_image_path"].as(); + const std::string surrogate_image_path = UQ_config["surrogate_image_path"].as(); + const std::string purify_config_path = UQ_config["purify_config_path"].as(); + double confidence; + double alpha; + if((UQ_config["confidence_interval"]) && (UQ_config["alpha"])) + { + std::cout << "Config should only contain one of 'confidence_interval' or 'alpha'." << std::endl; + return 1; + } + if(UQ_config["confidence_interval"]) + { + confidence = UQ_config["confidence_interval"].as(); + alpha = 1-confidence; + } + else if(UQ_config["alpha"]) + { + alpha = UQ_config["alpha"].as(); + confidence = 1 - alpha; + } + else + { + std::cout << "Config file must contain either 'confidence_interval' or 'alpha' as a parameter." << std::endl; + return 1; + } + const double sigma = UQ_config["sigma"].as(); + const double gamma = UQ_config["gamma"].as(); + purify::YamlParser purify_config = purify::YamlParser(purify_config_path); + + // Load the images and measurements const purify::utilities::vis_params measurement_data = purify::utilities::read_visibility(measurements_path, false); const auto reference_image = purify::pfitsio::read2d(ref_image_path); @@ -53,12 +79,11 @@ int main(int argc, char **argv) measurement_data, imsize_y, imsize_x, - 1, - 1, - 2, - purify::kernels::kernel_from_string.at("kb"), - 4, - 4 + purify_config.cellsizey(), + purify_config.cellsizex(), + purify_config.oversampling(), + purify::kernels::kernel_from_string.at(purify_config.kernel()), + purify_config.sim_J() ); if( ((*measurement_operator) * reference_vector).size() != measurement_data.vis.size()) From 58bacb8f6e88a466bca319d6b04ab82795bb358e Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 28 Sep 2023 13:14:52 +0100 Subject: [PATCH 05/32] move wavelet op construction into function --- cpp/main.cc | 56 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/cpp/main.cc b/cpp/main.cc index 20a7ba539..eb96d3760 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -22,6 +22,30 @@ #include using namespace purify; +struct waveletInfo +{ + std::shared_ptr> transform; + t_uint sara_size; +}; + +waveletInfo createWaveletOperator(YamlParser ¶ms, const factory::distributed_wavelet_operator &wop_algo) +{ + std::vector> sara; + for (size_t i = 0; i < params.wavelet_basis().size(); i++) + sara.push_back(std::make_tuple(params.wavelet_basis().at(i), params.wavelet_levels())); + t_uint sara_size = 0; +#ifdef PURIFY_MPI + { + auto const world = sopt::mpi::Communicator::World(); + if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) + sara = sopt::wavelets::distribute_sara(sara, world); + } +#endif + auto const wavelets_transform = factory::wavelet_operator_factory>( + wop_algo, sara, params.height(), params.width(), sara_size); + return {wavelets_transform, sara_size}; +} + int main(int argc, const char **argv) { std::srand(static_cast(std::time(0))); std::mt19937 mersnne(std::time(0)); @@ -409,19 +433,7 @@ int main(int argc, const char **argv) { pfitsio::write2d(dirty_image / beam_units, dirty_header, true); } // create wavelet operator - std::vector> sara; - for (size_t i = 0; i < params.wavelet_basis().size(); i++) - sara.push_back(std::make_tuple(params.wavelet_basis().at(i), params.wavelet_levels())); - t_uint sara_size = 0; -#ifdef PURIFY_MPI - { - auto const world = sopt::mpi::Communicator::World(); - if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) - sara = sopt::wavelets::distribute_sara(sara, world); - } -#endif - auto const wavelets_transform = factory::wavelet_operator_factory>( - wop_algo, sara, params.height(), params.width(), sara_size); + const waveletInfo wavelets = createWaveletOperator(params, wop_algo); // Create algorithm std::shared_ptr> padmm; @@ -429,8 +441,8 @@ int main(int argc, const char **argv) { std::shared_ptr> primaldual; if (params.algorithm() == "padmm") padmm = factory::padmm_factory>( - params.mpiAlgorithm(), measurements_transform, wavelets_transform, uv_data, - sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), sara_size, + params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, + sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), wavelets.sara_size, params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(), (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and (not params.positiveValueConstraint()), @@ -438,10 +450,10 @@ int main(int argc, const char **argv) { params.epsilonConvergenceScaling(), operator_norm); if (params.algorithm() == "fb") fb = factory::fb_factory>( - params.mpiAlgorithm(), measurements_transform, wavelets_transform, uv_data, + params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, sigma * params.epsilonScaling() / flux_scale, params.stepsize() * std::pow(sigma * params.epsilonScaling() / flux_scale, 2), - params.regularisation_parameter(), params.height(), params.width(), sara_size, + params.regularisation_parameter(), params.height(), params.width(), wavelets.sara_size, params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(), (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and (not params.positiveValueConstraint()), @@ -449,8 +461,8 @@ int main(int argc, const char **argv) { params.model_path(), params.gProximalType()); if (params.algorithm() == "primaldual") primaldual = factory::primaldual_factory>( - params.mpiAlgorithm(), measurements_transform, wavelets_transform, uv_data, - sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), sara_size, + params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, + sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), wavelets.sara_size, params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(), params.relVarianceConvergence(), params.epsilonConvergenceScaling(), operator_norm); // Add primal dual preconditioning @@ -478,21 +490,21 @@ int main(int argc, const char **argv) { // Adding step size update to algorithm factory::add_updater>( algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol, - update_header_res, params.height(), params.width(), sara_size, using_mpi, beam_units); + update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi, beam_units); } if (params.algorithm() == "primaldual") { const std::weak_ptr> algo_weak(primaldual); // Adding step size update to algorithm factory::add_updater>( algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol, - update_header_res, params.height(), params.width(), sara_size, using_mpi, beam_units); + update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi, beam_units); } if (params.algorithm() == "fb") { const std::weak_ptr> algo_weak(fb); // Adding step size update to algorithm factory::add_updater>( algo_weak, 0, params.update_tolerance(), 0, update_header_sol, update_header_res, - params.height(), params.width(), sara_size, using_mpi, beam_units); + params.height(), params.width(), wavelets.sara_size, using_mpi, beam_units); } PURIFY_HIGH_LOG("Starting sopt!"); From 385f263ac69ace769c8ed63e37432082233e5d78 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 28 Sep 2023 13:53:01 +0100 Subject: [PATCH 06/32] Move measurement op & input data setup out of main --- cpp/main.cc | 91 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 58 insertions(+), 33 deletions(-) diff --git a/cpp/main.cc b/cpp/main.cc index eb96d3760..8c3f66933 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -46,42 +46,15 @@ waveletInfo createWaveletOperator(YamlParser ¶ms, const factory::distributed return {wavelets_transform, sara_size}; } -int main(int argc, const char **argv) { - std::srand(static_cast(std::time(0))); - std::mt19937 mersnne(std::time(0)); - sopt::logging::initialize(); - purify::logging::initialize(); - - // Read config file path from command line - if (argc == 1) { - PURIFY_HIGH_LOG("Specify the config file full path. Aborting."); - return 1; - } - - std::string file_path = argv[1]; - YamlParser params = YamlParser(file_path); - if (params.version() != purify::version()) - throw std::runtime_error( - "Using purify version " + purify::version() + - " but the configuration file expects version " + params.version() + - ". Please updated the config version manually to be compatable with the new version."); - +std::tuple selectOperators(YamlParser ¶ms) +{ factory::distributed_measurement_operator mop_algo = (not params.gpu()) ? factory::distributed_measurement_operator::serial : factory::distributed_measurement_operator::gpu_serial; factory::distributed_wavelet_operator wop_algo = factory::distributed_wavelet_operator::serial; bool using_mpi = false; - std::vector image_index = std::vector(); - std::vector w_stacks = std::vector(); - -#ifdef PURIFY_MPI - auto const session = sopt::mpi::init(argc, argv); -#endif - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); -#else +#ifndef PURIFY_MPI throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); #endif mop_algo = (not params.gpu()) @@ -99,13 +72,28 @@ int main(int argc, const char **argv) { } using_mpi = true; } + return {mop_algo, wop_algo, using_mpi}; +} - sopt::logging::set_level(params.logging()); - purify::logging::set_level(params.logging()); +struct input_data +{ + utilities::vis_params uv_data; + t_real sigma; + Vector measurement_op_eigen_vector; + std::vector image_index; + std::vector w_stacks; +}; - // Read or generate input data +input_data getInputData(YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, + const bool using_mpi) +{ utilities::vis_params uv_data; t_real sigma; + std::vector image_index = std::vector(); + std::vector w_stacks = std::vector(); + Vector measurement_op_eigen_vector = Vector::Ones(params.width() * params.height()); // read eigen vector for power method @@ -277,6 +265,43 @@ int main(int argc, const char **argv) { params.oversampling()), widefield::equivalent_miriad_cell_size(params.cellsizey(), params.height(), params.oversampling())); + + return {uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks}; +} + +int main(int argc, const char **argv) { + std::srand(static_cast(std::time(0))); + std::mt19937 mersnne(std::time(0)); + sopt::logging::initialize(); + purify::logging::initialize(); + + // Read config file path from command line + if (argc == 1) { + PURIFY_HIGH_LOG("Specify the config file full path. Aborting."); + return 1; + } + + std::string file_path = argv[1]; + YamlParser params = YamlParser(file_path); + if (params.version() != purify::version()) + throw std::runtime_error( + "Using purify version " + purify::version() + + " but the configuration file expects version " + params.version() + + ". Please updated the config version manually to be compatable with the new version."); + +#ifdef PURIFY_MPI + auto const session = sopt::mpi::init(argc, argv); +#endif + + const auto [mop_algo, wop_algo, using_mpi] = selectOperators(params); + + sopt::logging::set_level(params.logging()); + purify::logging::set_level(params.logging()); + + // Read or generate input data + auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = + getInputData(params, mop_algo, wop_algo, using_mpi); + // create measurement operator std::shared_ptr>> measurements_transform; if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and From 5b75d7c873306e089e4eb499d42c8306e553d44b Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Sun, 1 Oct 2023 16:28:57 +0100 Subject: [PATCH 07/32] Refactor out measurement operator creation --- cpp/main.cc | 102 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 66 insertions(+), 36 deletions(-) diff --git a/cpp/main.cc b/cpp/main.cc index 8c3f66933..be784ff5a 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -75,7 +75,7 @@ std::tuple w_stacks; }; -input_data getInputData(YamlParser ¶ms, +inputData getInputData(YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi) @@ -269,41 +269,22 @@ input_data getInputData(YamlParser ¶ms, return {uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks}; } -int main(int argc, const char **argv) { - std::srand(static_cast(std::time(0))); - std::mt19937 mersnne(std::time(0)); - sopt::logging::initialize(); - purify::logging::initialize(); - - // Read config file path from command line - if (argc == 1) { - PURIFY_HIGH_LOG("Specify the config file full path. Aborting."); - return 1; - } - - std::string file_path = argv[1]; - YamlParser params = YamlParser(file_path); - if (params.version() != purify::version()) - throw std::runtime_error( - "Using purify version " + purify::version() + - " but the configuration file expects version " + params.version() + - ". Please updated the config version manually to be compatable with the new version."); - -#ifdef PURIFY_MPI - auto const session = sopt::mpi::init(argc, argv); -#endif - - const auto [mop_algo, wop_algo, using_mpi] = selectOperators(params); - - sopt::logging::set_level(params.logging()); - purify::logging::set_level(params.logging()); - - // Read or generate input data - auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = - getInputData(params, mop_algo, wop_algo, using_mpi); +struct measurementOpInfo +{ + std::shared_ptr>> measurement_transform; + t_real operator_norm; +}; - // create measurement operator - std::shared_ptr>> measurements_transform; +measurementOpInfo createMeasurementOperator(YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, + const bool using_mpi, + const std::vector &image_index, + const std::vector &w_stacks, + const utilities::vis_params &uv_data, + Vector &measurement_op_eigen_vector) +{ + std::shared_ptr>> measurements_transform; if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) measurements_transform = @@ -354,6 +335,55 @@ int main(int argc, const char **argv) { measurement_op_eigen_vector = std::get<1>(power_method_result); operator_norm = std::get<0>(power_method_result); } + + return {measurements_transform, operator_norm}; +} + +int main(int argc, const char **argv) { + std::srand(static_cast(std::time(0))); + std::mt19937 mersnne(std::time(0)); + sopt::logging::initialize(); + purify::logging::initialize(); + + // Read config file path from command line + if (argc == 1) { + PURIFY_HIGH_LOG("Specify the config file full path. Aborting."); + return 1; + } + + std::string file_path = argv[1]; + YamlParser params = YamlParser(file_path); + if (params.version() != purify::version()) + throw std::runtime_error( + "Using purify version " + purify::version() + + " but the configuration file expects version " + params.version() + + ". Please updated the config version manually to be compatable with the new version."); + +#ifdef PURIFY_MPI + auto const session = sopt::mpi::init(argc, argv); +#endif + + const auto [mop_algo, wop_algo, using_mpi] = selectOperators(params); + + sopt::logging::set_level(params.logging()); + purify::logging::set_level(params.logging()); + + // Read or generate input data + auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = getInputData(params, + mop_algo, + wop_algo, + using_mpi); + + // create measurement operator + auto [measurements_transform, operator_norm] = createMeasurementOperator(params, + mop_algo, + wop_algo, + using_mpi, + image_index, + w_stacks, + uv_data, + measurement_op_eigen_vector); + PURIFY_LOW_LOG("Value of operator norm is {}", operator_norm); t_real const flux_scale = 1.; uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale; From d8a63eef2f2685cdeedc3ebcc4719d1d88ff4e74 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Sun, 1 Oct 2023 16:31:46 +0100 Subject: [PATCH 08/32] Move wavelet operator call closer to other setup calls --- cpp/main.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/main.cc b/cpp/main.cc index be784ff5a..0834b2189 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -384,6 +384,9 @@ int main(int argc, const char **argv) { uv_data, measurement_op_eigen_vector); + // create wavelet operator + const waveletInfo wavelets = createWaveletOperator(params, wop_algo); + PURIFY_LOW_LOG("Value of operator norm is {}", operator_norm); t_real const flux_scale = 1.; uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale; @@ -487,8 +490,7 @@ int main(int argc, const char **argv) { } else { pfitsio::write2d(dirty_image / beam_units, dirty_header, true); } - // create wavelet operator - const waveletInfo wavelets = createWaveletOperator(params, wop_algo); + // Create algorithm std::shared_ptr> padmm; From 38320000872ee3c0a8e68393aa79aafedb472674 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Sun, 1 Oct 2023 18:04:49 +0100 Subject: [PATCH 09/32] Refactor out save functions --- cpp/main.cc | 215 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 136 insertions(+), 79 deletions(-) diff --git a/cpp/main.cc b/cpp/main.cc index 0834b2189..fa954788b 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -84,10 +84,10 @@ struct inputData std::vector w_stacks; }; -inputData getInputData(YamlParser ¶ms, - const factory::distributed_measurement_operator mop_algo, - const factory::distributed_wavelet_operator wop_algo, - const bool using_mpi) +inputData getInputData(const YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, + const bool using_mpi) { utilities::vis_params uv_data; t_real sigma; @@ -275,7 +275,7 @@ struct measurementOpInfo t_real operator_norm; }; -measurementOpInfo createMeasurementOperator(YamlParser ¶ms, +measurementOpInfo createMeasurementOperator(const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, @@ -339,60 +339,8 @@ measurementOpInfo createMeasurementOperator(YamlParser ¶ms, return {measurements_transform, operator_norm}; } -int main(int argc, const char **argv) { - std::srand(static_cast(std::time(0))); - std::mt19937 mersnne(std::time(0)); - sopt::logging::initialize(); - purify::logging::initialize(); - - // Read config file path from command line - if (argc == 1) { - PURIFY_HIGH_LOG("Specify the config file full path. Aborting."); - return 1; - } - - std::string file_path = argv[1]; - YamlParser params = YamlParser(file_path); - if (params.version() != purify::version()) - throw std::runtime_error( - "Using purify version " + purify::version() + - " but the configuration file expects version " + params.version() + - ". Please updated the config version manually to be compatable with the new version."); - -#ifdef PURIFY_MPI - auto const session = sopt::mpi::init(argc, argv); -#endif - - const auto [mop_algo, wop_algo, using_mpi] = selectOperators(params); - - sopt::logging::set_level(params.logging()); - purify::logging::set_level(params.logging()); - - // Read or generate input data - auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = getInputData(params, - mop_algo, - wop_algo, - using_mpi); - - // create measurement operator - auto [measurements_transform, operator_norm] = createMeasurementOperator(params, - mop_algo, - wop_algo, - using_mpi, - image_index, - w_stacks, - uv_data, - measurement_op_eigen_vector); - - // create wavelet operator - const waveletInfo wavelets = createWaveletOperator(params, wop_algo); - - PURIFY_LOW_LOG("Value of operator norm is {}", operator_norm); - t_real const flux_scale = 1.; - uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale; - - // Save some things before applying the algorithm - // the config yaml file - this also generates the output directory and the timestamp +void initOutDirectoryWithConfig(YamlParser ¶ms) +{ if (params.mpiAlgorithm() != factory::algo_distribution::serial) { #ifdef PURIFY_MPI auto const world = sopt::mpi::Communicator::World(); @@ -404,20 +352,34 @@ int main(int argc, const char **argv) { } else { params.writeOutput(); } - const std::string out_dir = params.output_prefix() + "/output_" + params.timestamp(); - // Creating header for saving output images during iterations +} + +struct Headers +{ + pfitsio::header_params solution_header; + pfitsio::header_params residuals_header; + pfitsio::header_params def_header; +}; + +Headers genHeaders(const YamlParser ¶ms, const utilities::vis_params &uv_data) +{ const pfitsio::header_params update_header_sol = - pfitsio::header_params(out_dir + "/sol_update.fits", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, + pfitsio::header_params(params.output_path() + "/sol_update.fits", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, params.measurements_polarization(), params.cellsizex(), params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); const pfitsio::header_params update_header_res = - pfitsio::header_params(out_dir + "/res_update.fits", "Jy/Beam", 1, uv_data.ra, uv_data.dec, + pfitsio::header_params(params.output_path() + "/res_update.fits", "Jy/Beam", 1, uv_data.ra, uv_data.dec, params.measurements_polarization(), params.cellsizex(), params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); const pfitsio::header_params def_header = pfitsio::header_params( "", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, params.measurements_polarization(), params.cellsizex(), params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); - // the eigenvector + + return {update_header_sol, update_header_res, def_header}; +} + +void saveMeasurementEigenVector(const YamlParser ¶ms, const Vector &measurement_op_eigen_vector) +{ if (params.mpiAlgorithm() != factory::algo_distribution::serial) { #ifdef PURIFY_MPI auto const world = sopt::mpi::Communicator::World(); @@ -427,19 +389,29 @@ int main(int argc, const char **argv) { #endif { pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), - out_dir + "/eigenvector_real.fits", "pix", true); + params.output_path() + "/eigenvector_real.fits", "pix", true); pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), - out_dir + "/eigenvector_imag.fits", "pix", true); + params.output_path() + "/eigenvector_imag.fits", "pix", true); } } else { pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), - out_dir + "/eigenvector_real.fits", "pix", true); + params.output_path() + "/eigenvector_real.fits", "pix", true); pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), - out_dir + "/eigenvector_imag.fits", "pix", true); + params.output_path() + "/eigenvector_imag.fits", "pix", true); } - // the psf - pfitsio::header_params psf_header = def_header; - psf_header.fits_name = out_dir + "/psf.fits"; +} + +void savePSF(const YamlParser ¶ms, + const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, + const t_real flux_scale, + const t_real sigma, + const t_real operator_norm, + const t_real beam_units) +{ + pfitsio::header_params psf_header = def_header; + psf_header.fits_name = params.output_path() + "/psf.fits"; psf_header.pix_units = "Jy/Pixel"; const Vector psf = measurements_transform->adjoint() * (uv_data.weights / flux_scale); const Image psf_image = @@ -447,11 +419,9 @@ int main(int argc, const char **argv) { PURIFY_HIGH_LOG( "Peak of PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", psf_image(static_cast(params.width() * 0.5 + params.height() * 0.5 * params.width()))); - t_real beam_units = 1.; if (params.mpiAlgorithm() != factory::algo_distribution::serial) { #ifdef PURIFY_MPI auto const world = sopt::mpi::Communicator::World(); - beam_units = world.all_sum_all(uv_data.size()) / flux_scale / flux_scale; PURIFY_LOW_LOG( "Expected image domain residual RMS is {} jy/beam", sigma * params.epsilonScaling() * operator_norm / @@ -462,7 +432,6 @@ int main(int argc, const char **argv) { #endif pfitsio::write2d(psf_image, psf_header, true); } else { - beam_units = uv_data.size() / flux_scale / flux_scale; PURIFY_LOW_LOG("Expected image domain residual RMS is {} jy/beam", sigma * params.epsilonScaling() * operator_norm / (std::sqrt(params.width() * params.height()) * uv_data.size())); @@ -472,9 +441,16 @@ int main(int argc, const char **argv) { "Theoretical calculation for peak PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", beam_units); PURIFY_HIGH_LOG("Effective sigma is {} Jy", sigma * params.epsilonScaling()); - // the dirty image - pfitsio::header_params dirty_header = def_header; - dirty_header.fits_name = out_dir + "/dirty.fits"; +} + +void saveDirtyImage(const YamlParser ¶ms, + const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, + const t_real beam_units) +{ + pfitsio::header_params dirty_header = def_header; + dirty_header.fits_name = params.output_path() + "/dirty.fits"; dirty_header.pix_units = "Jy/Beam"; const Vector dimage = measurements_transform->adjoint() * uv_data.vis; const Image dirty_image = @@ -490,6 +466,87 @@ int main(int argc, const char **argv) { } else { pfitsio::write2d(dirty_image / beam_units, dirty_header, true); } +} + +int main(int argc, const char **argv) { + std::srand(static_cast(std::time(0))); + std::mt19937 mersnne(std::time(0)); + sopt::logging::initialize(); + purify::logging::initialize(); + + // Read config file path from command line + if (argc == 1) { + PURIFY_HIGH_LOG("Specify the config file full path. Aborting."); + return 1; + } + + std::string file_path = argv[1]; + YamlParser params = YamlParser(file_path); + if (params.version() != purify::version()) + throw std::runtime_error( + "Using purify version " + purify::version() + + " but the configuration file expects version " + params.version() + + ". Please updated the config version manually to be compatable with the new version."); + +#ifdef PURIFY_MPI + auto const session = sopt::mpi::init(argc, argv); +#endif + + const auto [mop_algo, wop_algo, using_mpi] = selectOperators(params); + + sopt::logging::set_level(params.logging()); + purify::logging::set_level(params.logging()); + + // Read or generate input data + auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = getInputData(params, + mop_algo, + wop_algo, + using_mpi); + + // create measurement operator + auto [measurements_transform, operator_norm] = createMeasurementOperator(params, + mop_algo, + wop_algo, + using_mpi, + image_index, + w_stacks, + uv_data, + measurement_op_eigen_vector); + + // create wavelet operator + const waveletInfo wavelets = createWaveletOperator(params, wop_algo); + + PURIFY_LOW_LOG("Value of operator norm is {}", operator_norm); + t_real const flux_scale = 1.; + uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale; + + // Save some things before applying the algorithm + // the config yaml file - this also generates the output directory and the timestamp + initOutDirectoryWithConfig(params); + + // Creating header for saving output images during iterations + const auto [update_header_sol, update_header_res, def_header] = genHeaders(params, uv_data); + + // the eigenvector + saveMeasurementEigenVector(params, measurement_op_eigen_vector); + + // the psf + t_real beam_units = 1.0; + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + beam_units = world.all_sum_all(uv_data.size()) / flux_scale / flux_scale; +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + } else { + beam_units = uv_data.size() / flux_scale / flux_scale; + } + + savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, operator_norm, beam_units); + + // the dirty image + saveDirtyImage(params, def_header, measurements_transform, uv_data, beam_units); // Create algorithm @@ -568,7 +625,7 @@ int main(int argc, const char **argv) { Image image; Image residual_image; pfitsio::header_params purified_header = def_header; - purified_header.fits_name = out_dir + "/purified.fits"; + purified_header.fits_name = params.output_path() + "/purified.fits"; const Vector estimate_image = (params.warm_start() != "") ? Vector::Map(pfitsio::read2d(params.warm_start()).data(), @@ -631,7 +688,7 @@ int main(int argc, const char **argv) { } // the residuals pfitsio::header_params residuals_header = purified_header; - residuals_header.fits_name = out_dir + "/residuals.fits"; + residuals_header.fits_name = params.output_path() + "/residuals.fits"; residuals_header.pix_units = "Jy/Beam"; if (params.mpiAlgorithm() != factory::algo_distribution::serial) { #ifdef PURIFY_MPI From e6d013b66f135804339f86043631b021c5f6476b Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Sun, 1 Oct 2023 18:05:36 +0100 Subject: [PATCH 10/32] Add out_path to parser to prevent inconsistencies --- cpp/purify/yaml-parser.cc | 2 +- cpp/purify/yaml-parser.h | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/cpp/purify/yaml-parser.cc b/cpp/purify/yaml-parser.cc index 2a20fee04..0a3a11f7b 100644 --- a/cpp/purify/yaml-parser.cc +++ b/cpp/purify/yaml-parser.cc @@ -303,7 +303,7 @@ void YamlParser::writeOutput() { base_file_name.substr((file_path.size() ? file_path.size() + 1 : 0), base_file_name.size()); // Construct output directory structure and file name boost::filesystem::path const path(this->output_prefix_); - std::string const out_path = output_prefix_ + "/output_" + std::string(this->timestamp()); + out_path = output_prefix_ + "/output_" + std::string(this->timestamp()); mkdir_recursive(out_path); std::string out_filename = out_path + "/" + base_file_name + "_save.yaml"; diff --git a/cpp/purify/yaml-parser.h b/cpp/purify/yaml-parser.h index ede5b5373..d5e2b40bf 100644 --- a/cpp/purify/yaml-parser.h +++ b/cpp/purify/yaml-parser.h @@ -81,7 +81,7 @@ class YamlParser { TYPE NAME##_ = VALUE; \ \ public: \ - TYPE NAME() { return NAME##_; }; + TYPE NAME() const { return NAME##_; }; YAML_MACRO(std::string, filepath, "") YAML_MACRO(std::string, version, "") @@ -143,9 +143,15 @@ class YamlParser { YAML_MACRO(std::string, model_path, "") YAML_MACRO(factory::g_proximal_type, gProximalType, factory::g_proximal_type::L1GProximal) + std::string output_path() const + { + return out_path; + } + #undef YAML_MACRO private: YAML::Node config_file; + std::string out_path; template T get(const YAML::Node& node_map, const std::initializer_list indicies); From 4d160cb33f5edad9020ada58e187870280dac2b0 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Sun, 1 Oct 2023 18:12:14 +0100 Subject: [PATCH 11/32] Returning a struct for consistency --- cpp/main.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cpp/main.cc b/cpp/main.cc index fa954788b..b0aba65a8 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -46,7 +46,14 @@ waveletInfo createWaveletOperator(YamlParser ¶ms, const factory::distributed return {wavelets_transform, sara_size}; } -std::tuple selectOperators(YamlParser ¶ms) +struct OperatorsInfo +{ + factory::distributed_measurement_operator mop_algo; + factory::distributed_wavelet_operator wop_algo; + bool using_mpi; +}; + +OperatorsInfo selectOperators(YamlParser ¶ms) { factory::distributed_measurement_operator mop_algo = (not params.gpu()) ? factory::distributed_measurement_operator::serial From 9087faabff1e71ef2ddf88600361bab76766c88e Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Sun, 1 Oct 2023 18:55:44 +0100 Subject: [PATCH 12/32] Move refactored fns into new (poorly named) file --- cpp/main.cc | 454 +------------------------------------- cpp/purify/CMakeLists.txt | 2 +- cpp/purify/setup_utils.cc | 422 +++++++++++++++++++++++++++++++++++ cpp/purify/setup_utils.h | 83 +++++++ 4 files changed, 508 insertions(+), 453 deletions(-) create mode 100644 cpp/purify/setup_utils.cc create mode 100644 cpp/purify/setup_utils.h diff --git a/cpp/main.cc b/cpp/main.cc index b0aba65a8..38111a7ae 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -20,460 +20,10 @@ #include #include #include -using namespace purify; - -struct waveletInfo -{ - std::shared_ptr> transform; - t_uint sara_size; -}; - -waveletInfo createWaveletOperator(YamlParser ¶ms, const factory::distributed_wavelet_operator &wop_algo) -{ - std::vector> sara; - for (size_t i = 0; i < params.wavelet_basis().size(); i++) - sara.push_back(std::make_tuple(params.wavelet_basis().at(i), params.wavelet_levels())); - t_uint sara_size = 0; -#ifdef PURIFY_MPI - { - auto const world = sopt::mpi::Communicator::World(); - if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) - sara = sopt::wavelets::distribute_sara(sara, world); - } -#endif - auto const wavelets_transform = factory::wavelet_operator_factory>( - wop_algo, sara, params.height(), params.width(), sara_size); - return {wavelets_transform, sara_size}; -} - -struct OperatorsInfo -{ - factory::distributed_measurement_operator mop_algo; - factory::distributed_wavelet_operator wop_algo; - bool using_mpi; -}; +#include "purify/setup_utils.h" -OperatorsInfo selectOperators(YamlParser ¶ms) -{ - factory::distributed_measurement_operator mop_algo = - (not params.gpu()) ? factory::distributed_measurement_operator::serial - : factory::distributed_measurement_operator::gpu_serial; - factory::distributed_wavelet_operator wop_algo = factory::distributed_wavelet_operator::serial; - bool using_mpi = false; - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifndef PURIFY_MPI - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - mop_algo = (not params.gpu()) - ? factory::distributed_measurement_operator::mpi_distribute_image - : factory::distributed_measurement_operator::gpu_mpi_distribute_image; - if (params.mpi_all_to_all()) - mop_algo = (not params.gpu()) - ? factory::distributed_measurement_operator::mpi_distribute_all_to_all - : factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all; - wop_algo = factory::distributed_wavelet_operator::mpi_sara; - if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) { - mop_algo = (not params.gpu()) ? factory::distributed_measurement_operator::serial - : factory::distributed_measurement_operator::serial; - wop_algo = factory::distributed_wavelet_operator::serial; - } - using_mpi = true; - } - return {mop_algo, wop_algo, using_mpi}; -} - -struct inputData -{ - utilities::vis_params uv_data; - t_real sigma; - Vector measurement_op_eigen_vector; - std::vector image_index; - std::vector w_stacks; -}; - -inputData getInputData(const YamlParser ¶ms, - const factory::distributed_measurement_operator mop_algo, - const factory::distributed_wavelet_operator wop_algo, - const bool using_mpi) -{ - utilities::vis_params uv_data; - t_real sigma; - std::vector image_index = std::vector(); - std::vector w_stacks = std::vector(); - - Vector measurement_op_eigen_vector = - Vector::Ones(params.width() * params.height()); - // read eigen vector for power method - if (params.eigenvector_real() != "" and params.eigenvector_imag() != "") { - t_int rows; - t_int cols; - t_int pols; - t_int chans; - Vector temp_real; - Vector temp_imag; - pfitsio::read3d(params.eigenvector_real(), temp_real, rows, cols, chans, pols); - if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1) - throw std::runtime_error("Image of measurement operator eigenvector is wrong size."); - pfitsio::read3d(params.eigenvector_imag(), temp_imag, rows, cols, chans, pols); - if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1) - throw std::runtime_error("Image of measurement operator eigenvector is wrong size."); - measurement_op_eigen_vector.real() = temp_real; - measurement_op_eigen_vector.imag() = temp_imag; - } - if (params.source() == purify::utilities::vis_source::measurements) { - PURIFY_HIGH_LOG("Input visibilities are from files:"); - for (size_t i = 0; i < params.measurements().size(); i++) - PURIFY_HIGH_LOG("{}", params.measurements()[i]); - sigma = params.measurements_sigma(); -#ifdef PURIFY_MPI - if (using_mpi) { - auto const world = sopt::mpi::Communicator::World(); - uv_data = read_measurements::read_measurements(params.measurements(), world, - distribute::plan::radial, true, stokes::I, - params.measurements_units()); - const t_real norm = - std::sqrt(world.all_sum_all( - (uv_data.weights.real().array() * uv_data.weights.real().array()).sum()) / - world.all_sum_all(uv_data.size())); - // normalise weights - uv_data.weights = uv_data.weights / norm; - // using no weights for now - // uv_data.weights = Vector::Ones(uv_data.size()); - } else -#endif - { - uv_data = read_measurements::read_measurements(params.measurements(), true, stokes::I, - params.measurements_units()); - const t_real norm = std::sqrt( - (uv_data.weights.real().array() * uv_data.weights.real().array()).sum() / uv_data.size()); - // normalising weights - uv_data.weights = uv_data.weights / norm; - // using no weights for now - // uv_data.weights = Vector::Ones(uv_data.size()); - } - if (params.conjugate_w()) uv_data = utilities::conjugate_w(uv_data); -#ifdef PURIFY_MPI - if (params.mpi_wstacking() and - (mop_algo == factory::distributed_measurement_operator::mpi_distribute_all_to_all or - mop_algo == factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all)) { - auto const world = sopt::mpi::Communicator::World(); - const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; - const t_real du = - widefield::pixel_to_lambda(params.cellsizex(), params.width(), params.oversampling()); - std::tie(uv_data, image_index, w_stacks) = utilities::w_stacking_with_all_to_all( - uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost); - } else if (params.mpi_wstacking()) { - auto const world = sopt::mpi::Communicator::World(); - const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; - uv_data = utilities::w_stacking(uv_data, world, params.kmeans_iters(), cost); - } -#endif - } else if (params.source() == purify::utilities::vis_source::simulation) { - PURIFY_HIGH_LOG("Input visibilities will be generated for random coverage."); - // TODO: move this to function (in utilities.h?) - auto image = pfitsio::read2d(params.skymodel()); - if (params.height() != image.rows() || params.width() != image.cols()) - throw std::runtime_error("Input image size (" + std::to_string(image.cols()) + "x" + - std::to_string(image.rows()) + ") is not equal to the input one (" + - std::to_string(params.width()) + "x" + - std::to_string(params.height()) + ")."); - t_int const number_of_pixels = image.size(); - t_int const number_of_vis = params.number_of_measurements(); - t_real const sigma_m = constant::pi / 4; - const t_real rms_w = params.w_rms(); // lambda - if (params.measurements().at(0) == "") { - uv_data = utilities::random_sample_density(number_of_vis, 0, sigma_m, rms_w); - uv_data.units = utilities::vis_units::radians; - uv_data.weights = Vector::Ones(uv_data.size()); - } else { -#ifdef PURIFY_MPI - if (using_mpi) { - auto const world = sopt::mpi::Communicator::World(); - uv_data = read_measurements::read_measurements(params.measurements(), world, - distribute::plan::radial, true, stokes::I, - params.measurements_units()); - } else -#endif - uv_data = read_measurements::read_measurements(params.measurements(), true, stokes::I, - params.measurements_units()); - uv_data.weights = Vector::Ones(uv_data.weights.size()); - } - if (params.conjugate_w()) uv_data = utilities::conjugate_w(uv_data); -#ifdef PURIFY_MPI - if (params.mpi_wstacking() and - (mop_algo == factory::distributed_measurement_operator::mpi_distribute_all_to_all or - mop_algo == factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all)) { - auto const world = sopt::mpi::Communicator::World(); - const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; - const t_real du = - widefield::pixel_to_lambda(params.cellsizex(), params.width(), params.oversampling()); - std::tie(uv_data, image_index, w_stacks) = utilities::w_stacking_with_all_to_all( - uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost); - } else if (params.mpi_wstacking()) { - auto const world = sopt::mpi::Communicator::World(); - const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; - uv_data = utilities::w_stacking(uv_data, world, params.kmeans_iters(), cost); - } -#endif - std::shared_ptr>> sky_measurements; - if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and - mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) - sky_measurements = - (not params.wprojection()) - ? factory::measurement_operator_factory>( - mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), - params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.sim_J(), - params.mpi_wstacking()) - : factory::measurement_operator_factory>( - mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), - params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.Jw(), - params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); - else - sky_measurements = - (not params.wprojection()) - ? factory::all_to_all_measurement_operator_factory>( - mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), - params.cellsizey(), params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.sim_J(), - params.mpi_wstacking()) - : factory::all_to_all_measurement_operator_factory>( - mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), - params.cellsizey(), params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.Jw(), - params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); - uv_data.vis = - ((*sky_measurements) * Vector::Map(image.data(), image.size())).eval().array(); - sigma = utilities::SNR_to_standard_deviation(uv_data.vis, params.signal_to_noise()); - uv_data.vis = utilities::add_noise(uv_data.vis, 0., sigma); - } - t_real ideal_cell_x = widefield::estimate_cell_size(uv_data.u.cwiseAbs().maxCoeff(), - params.width(), params.oversampling()); - t_real ideal_cell_y = widefield::estimate_cell_size(uv_data.v.cwiseAbs().maxCoeff(), - params.height(), params.oversampling()); -#ifdef PURIFY_MPI - if (using_mpi) { - auto const comm = sopt::mpi::Communicator::World(); - ideal_cell_x = widefield::estimate_cell_size( - comm.all_reduce(uv_data.u.cwiseAbs().maxCoeff(), MPI_MAX), params.width(), - params.oversampling()); - ideal_cell_y = widefield::estimate_cell_size( - comm.all_reduce(uv_data.v.cwiseAbs().maxCoeff(), MPI_MAX), params.height(), - params.oversampling()); - } -#endif - PURIFY_HIGH_LOG( - "Using cell size {}\" x {}\", recommended from the uv coverage and field of view is " - "{}\"x{}\".", - params.cellsizey(), params.cellsizex(), ideal_cell_y, ideal_cell_x); - PURIFY_HIGH_LOG("The equivalent miriad cell size is: {}\" x {}\"", - widefield::equivalent_miriad_cell_size(params.cellsizex(), params.width(), - params.oversampling()), - widefield::equivalent_miriad_cell_size(params.cellsizey(), params.height(), - params.oversampling())); - - return {uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks}; -} - -struct measurementOpInfo -{ - std::shared_ptr>> measurement_transform; - t_real operator_norm; -}; - -measurementOpInfo createMeasurementOperator(const YamlParser ¶ms, - const factory::distributed_measurement_operator mop_algo, - const factory::distributed_wavelet_operator wop_algo, - const bool using_mpi, - const std::vector &image_index, - const std::vector &w_stacks, - const utilities::vis_params &uv_data, - Vector &measurement_op_eigen_vector) -{ - std::shared_ptr>> measurements_transform; - if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and - mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) - measurements_transform = - (not params.wprojection()) - ? factory::measurement_operator_factory>( - mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), - params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jx(), - params.mpi_wstacking()) - : factory::measurement_operator_factory>( - mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), - params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jw(), - params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); - else - measurements_transform = - (not params.wprojection()) - ? factory::all_to_all_measurement_operator_factory>( - mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), - params.cellsizey(), params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jx(), - params.mpi_wstacking()) - : factory::all_to_all_measurement_operator_factory>( - mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), - params.cellsizey(), params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jw(), - params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); - t_real operator_norm = 1.; -#ifdef PURIFY_MPI - if (using_mpi) { - auto const comm = sopt::mpi::Communicator::World(); - auto power_method_result = - (params.mpiAlgorithm() != factory::algo_distribution::mpi_random_updates) - ? sopt::algorithm::power_method>( - *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(), - comm.broadcast(measurement_op_eigen_vector).eval()) - : sopt::algorithm::all_sum_all_power_method>( - comm, *measurements_transform, params.powMethod_iter(), - params.powMethod_tolerance(), comm.broadcast(measurement_op_eigen_vector).eval()); - measurement_op_eigen_vector = std::get<1>(power_method_result); - operator_norm = std::get<0>(power_method_result); - } else -#endif - { - auto power_method_result = sopt::algorithm::power_method>( - *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(), - measurement_op_eigen_vector); - measurement_op_eigen_vector = std::get<1>(power_method_result); - operator_norm = std::get<0>(power_method_result); - } - - return {measurements_transform, operator_norm}; -} - -void initOutDirectoryWithConfig(YamlParser ¶ms) -{ - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); - if (world.is_root()) -#else - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - params.writeOutput(); - } else { - params.writeOutput(); - } -} - -struct Headers -{ - pfitsio::header_params solution_header; - pfitsio::header_params residuals_header; - pfitsio::header_params def_header; -}; - -Headers genHeaders(const YamlParser ¶ms, const utilities::vis_params &uv_data) -{ - const pfitsio::header_params update_header_sol = - pfitsio::header_params(params.output_path() + "/sol_update.fits", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, - params.measurements_polarization(), params.cellsizex(), - params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); - const pfitsio::header_params update_header_res = - pfitsio::header_params(params.output_path() + "/res_update.fits", "Jy/Beam", 1, uv_data.ra, uv_data.dec, - params.measurements_polarization(), params.cellsizex(), - params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); - const pfitsio::header_params def_header = pfitsio::header_params( - "", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, params.measurements_polarization(), - params.cellsizex(), params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); - - return {update_header_sol, update_header_res, def_header}; -} - -void saveMeasurementEigenVector(const YamlParser ¶ms, const Vector &measurement_op_eigen_vector) -{ - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); - if (world.is_root()) -#else - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - { - pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), - params.output_path() + "/eigenvector_real.fits", "pix", true); - pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), - params.output_path() + "/eigenvector_imag.fits", "pix", true); - } - } else { - pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), - params.output_path() + "/eigenvector_real.fits", "pix", true); - pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), - params.output_path() + "/eigenvector_imag.fits", "pix", true); - } -} - -void savePSF(const YamlParser ¶ms, - const pfitsio::header_params &def_header, - const std::shared_ptr>> &measurements_transform, - const utilities::vis_params &uv_data, - const t_real flux_scale, - const t_real sigma, - const t_real operator_norm, - const t_real beam_units) -{ - pfitsio::header_params psf_header = def_header; - psf_header.fits_name = params.output_path() + "/psf.fits"; - psf_header.pix_units = "Jy/Pixel"; - const Vector psf = measurements_transform->adjoint() * (uv_data.weights / flux_scale); - const Image psf_image = - Image::Map(psf.data(), params.height(), params.width()).real(); - PURIFY_HIGH_LOG( - "Peak of PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", - psf_image(static_cast(params.width() * 0.5 + params.height() * 0.5 * params.width()))); - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); - PURIFY_LOW_LOG( - "Expected image domain residual RMS is {} jy/beam", - sigma * params.epsilonScaling() * operator_norm / - (std::sqrt(params.width() * params.height()) * world.all_sum_all(uv_data.size()))); - if (world.is_root()) -#else - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - pfitsio::write2d(psf_image, psf_header, true); - } else { - PURIFY_LOW_LOG("Expected image domain residual RMS is {} jy/beam", - sigma * params.epsilonScaling() * operator_norm / - (std::sqrt(params.width() * params.height()) * uv_data.size())); - pfitsio::write2d(psf_image, psf_header, true); - } - PURIFY_HIGH_LOG( - "Theoretical calculation for peak PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", - beam_units); - PURIFY_HIGH_LOG("Effective sigma is {} Jy", sigma * params.epsilonScaling()); -} +using namespace purify; -void saveDirtyImage(const YamlParser ¶ms, - const pfitsio::header_params &def_header, - const std::shared_ptr>> &measurements_transform, - const utilities::vis_params &uv_data, - const t_real beam_units) -{ - pfitsio::header_params dirty_header = def_header; - dirty_header.fits_name = params.output_path() + "/dirty.fits"; - dirty_header.pix_units = "Jy/Beam"; - const Vector dimage = measurements_transform->adjoint() * uv_data.vis; - const Image dirty_image = - Image::Map(dimage.data(), params.height(), params.width()).real(); - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); - if (world.is_root()) -#else - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - pfitsio::write2d(dirty_image / beam_units, dirty_header, true); - } else { - pfitsio::write2d(dirty_image / beam_units, dirty_header, true); - } -} int main(int argc, const char **argv) { std::srand(static_cast(std::time(0))); diff --git a/cpp/purify/CMakeLists.txt b/cpp/purify/CMakeLists.txt index 07c71a792..e8d28875d 100644 --- a/cpp/purify/CMakeLists.txt +++ b/cpp/purify/CMakeLists.txt @@ -35,7 +35,7 @@ set(HEADERS set(SOURCES utilities.cc pfitsio.cc kernels.cc wproj_utilities.cc operators.cc uvfits.cc yaml-parser.cc read_measurements.cc distribute.cc integration.cc wide_field_utilities.cc wkernel_integration.cc - wproj_operators.cc uvw_utilities.cc) + wproj_operators.cc uvw_utilities.cc setup_utils.cc) if(TARGET casacore::ms) list(APPEND SOURCES casacore.cc) diff --git a/cpp/purify/setup_utils.cc b/cpp/purify/setup_utils.cc new file mode 100644 index 000000000..0721075b6 --- /dev/null +++ b/cpp/purify/setup_utils.cc @@ -0,0 +1,422 @@ +#include "purify/setup_utils.h" +#include + +using namespace purify; + +waveletInfo createWaveletOperator(YamlParser ¶ms, const factory::distributed_wavelet_operator &wop_algo) +{ + std::vector> sara; + for (size_t i = 0; i < params.wavelet_basis().size(); i++) + sara.push_back(std::make_tuple(params.wavelet_basis().at(i), params.wavelet_levels())); + t_uint sara_size = 0; +#ifdef PURIFY_MPI + { + auto const world = sopt::mpi::Communicator::World(); + if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) + sara = sopt::wavelets::distribute_sara(sara, world); + } +#endif + auto const wavelets_transform = factory::wavelet_operator_factory>( + wop_algo, sara, params.height(), params.width(), sara_size); + return {wavelets_transform, sara_size}; +} + +OperatorsInfo selectOperators(YamlParser ¶ms) +{ + factory::distributed_measurement_operator mop_algo = + (not params.gpu()) ? factory::distributed_measurement_operator::serial + : factory::distributed_measurement_operator::gpu_serial; + factory::distributed_wavelet_operator wop_algo = factory::distributed_wavelet_operator::serial; + bool using_mpi = false; + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifndef PURIFY_MPI + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + mop_algo = (not params.gpu()) + ? factory::distributed_measurement_operator::mpi_distribute_image + : factory::distributed_measurement_operator::gpu_mpi_distribute_image; + if (params.mpi_all_to_all()) + mop_algo = (not params.gpu()) + ? factory::distributed_measurement_operator::mpi_distribute_all_to_all + : factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all; + wop_algo = factory::distributed_wavelet_operator::mpi_sara; + if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) { + mop_algo = (not params.gpu()) ? factory::distributed_measurement_operator::serial + : factory::distributed_measurement_operator::serial; + wop_algo = factory::distributed_wavelet_operator::serial; + } + using_mpi = true; + } + return {mop_algo, wop_algo, using_mpi}; +} + +inputData getInputData(const YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, + const bool using_mpi) +{ + utilities::vis_params uv_data; + t_real sigma; + std::vector image_index = std::vector(); + std::vector w_stacks = std::vector(); + + Vector measurement_op_eigen_vector = + Vector::Ones(params.width() * params.height()); + // read eigen vector for power method + if (params.eigenvector_real() != "" and params.eigenvector_imag() != "") { + t_int rows; + t_int cols; + t_int pols; + t_int chans; + Vector temp_real; + Vector temp_imag; + pfitsio::read3d(params.eigenvector_real(), temp_real, rows, cols, chans, pols); + if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1) + throw std::runtime_error("Image of measurement operator eigenvector is wrong size."); + pfitsio::read3d(params.eigenvector_imag(), temp_imag, rows, cols, chans, pols); + if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1) + throw std::runtime_error("Image of measurement operator eigenvector is wrong size."); + measurement_op_eigen_vector.real() = temp_real; + measurement_op_eigen_vector.imag() = temp_imag; + } + if (params.source() == purify::utilities::vis_source::measurements) { + PURIFY_HIGH_LOG("Input visibilities are from files:"); + for (size_t i = 0; i < params.measurements().size(); i++) + PURIFY_HIGH_LOG("{}", params.measurements()[i]); + sigma = params.measurements_sigma(); +#ifdef PURIFY_MPI + if (using_mpi) { + auto const world = sopt::mpi::Communicator::World(); + uv_data = read_measurements::read_measurements(params.measurements(), world, + distribute::plan::radial, true, stokes::I, + params.measurements_units()); + const t_real norm = + std::sqrt(world.all_sum_all( + (uv_data.weights.real().array() * uv_data.weights.real().array()).sum()) / + world.all_sum_all(uv_data.size())); + // normalise weights + uv_data.weights = uv_data.weights / norm; + // using no weights for now + // uv_data.weights = Vector::Ones(uv_data.size()); + } else +#endif + { + uv_data = read_measurements::read_measurements(params.measurements(), true, stokes::I, + params.measurements_units()); + const t_real norm = std::sqrt( + (uv_data.weights.real().array() * uv_data.weights.real().array()).sum() / uv_data.size()); + // normalising weights + uv_data.weights = uv_data.weights / norm; + // using no weights for now + // uv_data.weights = Vector::Ones(uv_data.size()); + } + if (params.conjugate_w()) uv_data = utilities::conjugate_w(uv_data); +#ifdef PURIFY_MPI + if (params.mpi_wstacking() and + (mop_algo == factory::distributed_measurement_operator::mpi_distribute_all_to_all or + mop_algo == factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all)) { + auto const world = sopt::mpi::Communicator::World(); + const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; + const t_real du = + widefield::pixel_to_lambda(params.cellsizex(), params.width(), params.oversampling()); + std::tie(uv_data, image_index, w_stacks) = utilities::w_stacking_with_all_to_all( + uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost); + } else if (params.mpi_wstacking()) { + auto const world = sopt::mpi::Communicator::World(); + const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; + uv_data = utilities::w_stacking(uv_data, world, params.kmeans_iters(), cost); + } +#endif + } else if (params.source() == purify::utilities::vis_source::simulation) { + PURIFY_HIGH_LOG("Input visibilities will be generated for random coverage."); + // TODO: move this to function (in utilities.h?) + auto image = pfitsio::read2d(params.skymodel()); + if (params.height() != image.rows() || params.width() != image.cols()) + throw std::runtime_error("Input image size (" + std::to_string(image.cols()) + "x" + + std::to_string(image.rows()) + ") is not equal to the input one (" + + std::to_string(params.width()) + "x" + + std::to_string(params.height()) + ")."); + t_int const number_of_pixels = image.size(); + t_int const number_of_vis = params.number_of_measurements(); + t_real const sigma_m = constant::pi / 4; + const t_real rms_w = params.w_rms(); // lambda + if (params.measurements().at(0) == "") { + uv_data = utilities::random_sample_density(number_of_vis, 0, sigma_m, rms_w); + uv_data.units = utilities::vis_units::radians; + uv_data.weights = Vector::Ones(uv_data.size()); + } else { +#ifdef PURIFY_MPI + if (using_mpi) { + auto const world = sopt::mpi::Communicator::World(); + uv_data = read_measurements::read_measurements(params.measurements(), world, + distribute::plan::radial, true, stokes::I, + params.measurements_units()); + } else +#endif + uv_data = read_measurements::read_measurements(params.measurements(), true, stokes::I, + params.measurements_units()); + uv_data.weights = Vector::Ones(uv_data.weights.size()); + } + if (params.conjugate_w()) uv_data = utilities::conjugate_w(uv_data); +#ifdef PURIFY_MPI + if (params.mpi_wstacking() and + (mop_algo == factory::distributed_measurement_operator::mpi_distribute_all_to_all or + mop_algo == factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all)) { + auto const world = sopt::mpi::Communicator::World(); + const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; + const t_real du = + widefield::pixel_to_lambda(params.cellsizex(), params.width(), params.oversampling()); + std::tie(uv_data, image_index, w_stacks) = utilities::w_stacking_with_all_to_all( + uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost); + } else if (params.mpi_wstacking()) { + auto const world = sopt::mpi::Communicator::World(); + const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; + uv_data = utilities::w_stacking(uv_data, world, params.kmeans_iters(), cost); + } +#endif + std::shared_ptr>> sky_measurements; + if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and + mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) + sky_measurements = + (not params.wprojection()) + ? factory::measurement_operator_factory>( + mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), + params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.sim_J(), + params.mpi_wstacking()) + : factory::measurement_operator_factory>( + mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), + params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.Jw(), + params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); + else + sky_measurements = + (not params.wprojection()) + ? factory::all_to_all_measurement_operator_factory>( + mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), + params.cellsizey(), params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.sim_J(), + params.mpi_wstacking()) + : factory::all_to_all_measurement_operator_factory>( + mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), + params.cellsizey(), params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.Jw(), + params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); + uv_data.vis = + ((*sky_measurements) * Vector::Map(image.data(), image.size())).eval().array(); + sigma = utilities::SNR_to_standard_deviation(uv_data.vis, params.signal_to_noise()); + uv_data.vis = utilities::add_noise(uv_data.vis, 0., sigma); + } + t_real ideal_cell_x = widefield::estimate_cell_size(uv_data.u.cwiseAbs().maxCoeff(), + params.width(), params.oversampling()); + t_real ideal_cell_y = widefield::estimate_cell_size(uv_data.v.cwiseAbs().maxCoeff(), + params.height(), params.oversampling()); +#ifdef PURIFY_MPI + if (using_mpi) { + auto const comm = sopt::mpi::Communicator::World(); + ideal_cell_x = widefield::estimate_cell_size( + comm.all_reduce(uv_data.u.cwiseAbs().maxCoeff(), MPI_MAX), params.width(), + params.oversampling()); + ideal_cell_y = widefield::estimate_cell_size( + comm.all_reduce(uv_data.v.cwiseAbs().maxCoeff(), MPI_MAX), params.height(), + params.oversampling()); + } +#endif + PURIFY_HIGH_LOG( + "Using cell size {}\" x {}\", recommended from the uv coverage and field of view is " + "{}\"x{}\".", + params.cellsizey(), params.cellsizex(), ideal_cell_y, ideal_cell_x); + PURIFY_HIGH_LOG("The equivalent miriad cell size is: {}\" x {}\"", + widefield::equivalent_miriad_cell_size(params.cellsizex(), params.width(), + params.oversampling()), + widefield::equivalent_miriad_cell_size(params.cellsizey(), params.height(), + params.oversampling())); + + return {uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks}; +} + +measurementOpInfo createMeasurementOperator(const YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, + const bool using_mpi, + const std::vector &image_index, + const std::vector &w_stacks, + const utilities::vis_params &uv_data, + Vector &measurement_op_eigen_vector) +{ + std::shared_ptr>> measurements_transform; + if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and + mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) + measurements_transform = + (not params.wprojection()) + ? factory::measurement_operator_factory>( + mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), + params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jx(), + params.mpi_wstacking()) + : factory::measurement_operator_factory>( + mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), + params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jw(), + params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); + else + measurements_transform = + (not params.wprojection()) + ? factory::all_to_all_measurement_operator_factory>( + mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), + params.cellsizey(), params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jx(), + params.mpi_wstacking()) + : factory::all_to_all_measurement_operator_factory>( + mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), + params.cellsizey(), params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jw(), + params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); + t_real operator_norm = 1.; +#ifdef PURIFY_MPI + if (using_mpi) { + auto const comm = sopt::mpi::Communicator::World(); + auto power_method_result = + (params.mpiAlgorithm() != factory::algo_distribution::mpi_random_updates) + ? sopt::algorithm::power_method>( + *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(), + comm.broadcast(measurement_op_eigen_vector).eval()) + : sopt::algorithm::all_sum_all_power_method>( + comm, *measurements_transform, params.powMethod_iter(), + params.powMethod_tolerance(), comm.broadcast(measurement_op_eigen_vector).eval()); + measurement_op_eigen_vector = std::get<1>(power_method_result); + operator_norm = std::get<0>(power_method_result); + } else +#endif + { + auto power_method_result = sopt::algorithm::power_method>( + *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(), + measurement_op_eigen_vector); + measurement_op_eigen_vector = std::get<1>(power_method_result); + operator_norm = std::get<0>(power_method_result); + } + + return {measurements_transform, operator_norm}; +} + +void initOutDirectoryWithConfig(YamlParser ¶ms) +{ + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + if (world.is_root()) +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + params.writeOutput(); + } else { + params.writeOutput(); + } +} + +Headers genHeaders(const YamlParser ¶ms, const utilities::vis_params &uv_data) +{ + const pfitsio::header_params update_header_sol = + pfitsio::header_params(params.output_path() + "/sol_update.fits", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, + params.measurements_polarization(), params.cellsizex(), + params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); + const pfitsio::header_params update_header_res = + pfitsio::header_params(params.output_path() + "/res_update.fits", "Jy/Beam", 1, uv_data.ra, uv_data.dec, + params.measurements_polarization(), params.cellsizex(), + params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); + const pfitsio::header_params def_header = pfitsio::header_params( + "", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, params.measurements_polarization(), + params.cellsizex(), params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); + + return {update_header_sol, update_header_res, def_header}; +} + +void saveMeasurementEigenVector(const YamlParser ¶ms, const Vector &measurement_op_eigen_vector) +{ + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + if (world.is_root()) +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + { + pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), + params.output_path() + "/eigenvector_real.fits", "pix", true); + pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), + params.output_path() + "/eigenvector_imag.fits", "pix", true); + } + } else { + pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), + params.output_path() + "/eigenvector_real.fits", "pix", true); + pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), + params.output_path() + "/eigenvector_imag.fits", "pix", true); + } +} + +void savePSF(const YamlParser ¶ms, + const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, + const t_real flux_scale, + const t_real sigma, + const t_real operator_norm, + const t_real beam_units) +{ + pfitsio::header_params psf_header = def_header; + psf_header.fits_name = params.output_path() + "/psf.fits"; + psf_header.pix_units = "Jy/Pixel"; + const Vector psf = measurements_transform->adjoint() * (uv_data.weights / flux_scale); + const Image psf_image = + Image::Map(psf.data(), params.height(), params.width()).real(); + PURIFY_HIGH_LOG( + "Peak of PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", + psf_image(static_cast(params.width() * 0.5 + params.height() * 0.5 * params.width()))); + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + PURIFY_LOW_LOG( + "Expected image domain residual RMS is {} jy/beam", + sigma * params.epsilonScaling() * operator_norm / + (std::sqrt(params.width() * params.height()) * world.all_sum_all(uv_data.size()))); + if (world.is_root()) +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + pfitsio::write2d(psf_image, psf_header, true); + } else { + PURIFY_LOW_LOG("Expected image domain residual RMS is {} jy/beam", + sigma * params.epsilonScaling() * operator_norm / + (std::sqrt(params.width() * params.height()) * uv_data.size())); + pfitsio::write2d(psf_image, psf_header, true); + } + PURIFY_HIGH_LOG( + "Theoretical calculation for peak PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", + beam_units); + PURIFY_HIGH_LOG("Effective sigma is {} Jy", sigma * params.epsilonScaling()); +} + +void saveDirtyImage(const YamlParser ¶ms, + const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, + const t_real beam_units) +{ + pfitsio::header_params dirty_header = def_header; + dirty_header.fits_name = params.output_path() + "/dirty.fits"; + dirty_header.pix_units = "Jy/Beam"; + const Vector dimage = measurements_transform->adjoint() * uv_data.vis; + const Image dirty_image = + Image::Map(dimage.data(), params.height(), params.width()).real(); + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + if (world.is_root()) +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + pfitsio::write2d(dirty_image / beam_units, dirty_header, true); + } else { + pfitsio::write2d(dirty_image / beam_units, dirty_header, true); + } +} \ No newline at end of file diff --git a/cpp/purify/setup_utils.h b/cpp/purify/setup_utils.h new file mode 100644 index 000000000..1e597a9ac --- /dev/null +++ b/cpp/purify/setup_utils.h @@ -0,0 +1,83 @@ +#include "purify/types.h" +#include "purify/measurement_operator_factory.h" +#include "purify/wavelet_operator_factory.h" +#include "purify/pfitsio.h" +#include "purify/read_measurements.h" +#include "purify/yaml-parser.h" +#include "purify/logging.h" + +using namespace purify; + +struct waveletInfo +{ + std::shared_ptr> transform; + t_uint sara_size; +}; + +waveletInfo createWaveletOperator(YamlParser ¶ms, const factory::distributed_wavelet_operator &wop_algo); + +struct OperatorsInfo +{ + factory::distributed_measurement_operator mop_algo; + factory::distributed_wavelet_operator wop_algo; + bool using_mpi; +}; + +OperatorsInfo selectOperators(YamlParser ¶ms); + +struct inputData +{ + utilities::vis_params uv_data; + t_real sigma; + Vector measurement_op_eigen_vector; + std::vector image_index; + std::vector w_stacks; +}; + +inputData getInputData(const YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, + const bool using_mpi); + +struct measurementOpInfo +{ + std::shared_ptr>> measurement_transform; + t_real operator_norm; +}; + +measurementOpInfo createMeasurementOperator(const YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, + const bool using_mpi, + const std::vector &image_index, + const std::vector &w_stacks, + const utilities::vis_params &uv_data, + Vector &measurement_op_eigen_vector); + +void initOutDirectoryWithConfig(YamlParser ¶ms); + +struct Headers +{ + pfitsio::header_params solution_header; + pfitsio::header_params residuals_header; + pfitsio::header_params def_header; +}; + +Headers genHeaders(const YamlParser ¶ms, const utilities::vis_params &uv_data); + +void saveMeasurementEigenVector(const YamlParser ¶ms, const Vector &measurement_op_eigen_vector); + +void savePSF(const YamlParser ¶ms, + const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, + const t_real flux_scale, + const t_real sigma, + const t_real operator_norm, + const t_real beam_units); + +void saveDirtyImage(const YamlParser ¶ms, + const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, + const t_real beam_units); From 0920091a657caa11e72b56238f2cfc6713b3ed4c Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Sun, 1 Oct 2023 20:12:56 +0100 Subject: [PATCH 13/32] Add Purify config file read to UQ --- cpp/uncertainty_quantification/uq_main.cc | 98 ++++++++++++++++------- 1 file changed, 70 insertions(+), 28 deletions(-) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 67665f0f9..295fd88b0 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -6,6 +6,7 @@ #include #include "yaml-cpp/yaml.h" #include "purify/yaml-parser.h" +#include "purify/setup_utils.h" using VectorC = sopt::Vector>; @@ -24,10 +25,74 @@ int main(int argc, char **argv) const std::string config_path = argv[1]; const YAML::Node UQ_config = YAML::LoadFile(config_path); - const std::string measurements_path = UQ_config["measurements_path"].as(); + // Load the Reference and Surrogate images const std::string ref_image_path = UQ_config["reference_image_path"].as(); const std::string surrogate_image_path = UQ_config["surrogate_image_path"].as(); - const std::string purify_config_path = UQ_config["purify_config_path"].as(); + const auto reference_image = purify::pfitsio::read2d(ref_image_path); + const VectorC reference_vector = VectorC::Map(reference_image.data(), reference_image.size()); + const auto surrogate_image = purify::pfitsio::read2d(surrogate_image_path); + const VectorC surrogate_vector = VectorC::Map(surrogate_image.data(), surrogate_image.size()); + + + const uint imsize_x = reference_image.cols(); + const uint imsize_y = reference_image.rows(); + + // Prepare operators and data using either purify config + // If no purify config use basic version for now + purify::utilities::vis_params measurement_data; + std::shared_ptr>> measurement_operator; + if(UQ_config["purify_config_file"]) + { + YamlParser purify_config = YamlParser(UQ_config["purify_config_file"].as()); + + const auto [mop_algo, wop_algo, using_mpi] = selectOperators(purify_config); + auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = getInputData(purify_config, + mop_algo, + wop_algo, + using_mpi); + + auto [transform, operator_norm] = createMeasurementOperator(purify_config, + mop_algo, + wop_algo, + using_mpi, + image_index, + w_stacks, + uv_data, + measurement_op_eigen_vector); + + const waveletInfo wavelets = createWaveletOperator(purify_config, wop_algo); + + t_real const flux_scale = 1.; + uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale; + + measurement_data = uv_data; + measurement_operator = transform; + } + else + { + const std::string measurements_path = UQ_config["measurements_path"].as(); + // Load the images and measurements + const purify::utilities::vis_params uv_data = purify::utilities::read_visibility(measurements_path, false); + + // This is the measurement operator used in the test but this should probably be selectable + auto const transform = + purify::factory::measurement_operator_factory>>( + purify::factory::distributed_measurement_operator::serial, + uv_data, + imsize_y, + imsize_x, + 1, + 1, + 2, + kernels::kernel_from_string.at("kb"), + 4, + 4); + + measurement_operator = transform; + measurement_data = uv_data; + } + + // Set up confidence and objective function params double confidence; double alpha; if((UQ_config["confidence_interval"]) && (UQ_config["alpha"])) @@ -52,19 +117,7 @@ int main(int argc, char **argv) } const double sigma = UQ_config["sigma"].as(); const double gamma = UQ_config["gamma"].as(); - purify::YamlParser purify_config = purify::YamlParser(purify_config_path); - - // Load the images and measurements - const purify::utilities::vis_params measurement_data = purify::utilities::read_visibility(measurements_path, false); - - const auto reference_image = purify::pfitsio::read2d(ref_image_path); - const VectorC reference_vector = VectorC::Map(reference_image.data(), reference_image.size()); - const auto surrogate_image = purify::pfitsio::read2d(surrogate_image_path); - const VectorC surrogate_vector = VectorC::Map(surrogate_image.data(), surrogate_image.size()); - - const uint imsize_x = reference_image.cols(); - const uint imsize_y = reference_image.rows(); if((imsize_x != surrogate_image.cols()) || (imsize_y != surrogate_image.rows())) { @@ -72,26 +125,15 @@ int main(int argc, char **argv) return 2; } - // This is the measurement operator used in the test but this should probably be selectable - auto const measurement_operator = - purify::factory::measurement_operator_factory>>( - purify::factory::distributed_measurement_operator::serial, - measurement_data, - imsize_y, - imsize_x, - purify_config.cellsizey(), - purify_config.cellsizex(), - purify_config.oversampling(), - purify::kernels::kernel_from_string.at(purify_config.kernel()), - purify_config.sim_J() - ); - + if( ((*measurement_operator) * reference_vector).size() != measurement_data.vis.size()) { std::cout << "Image size is not compatible with the measurement operator and data provided." << std::endl; return 3; } + + // Calculate the posterior function for the reference image // posterior = likelihood + prior // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm) From efcefa89a9533983a07c2dea510c037b007d2e0f Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Mon, 2 Oct 2023 17:31:24 +0100 Subject: [PATCH 14/32] Add wavelet operator to prior --- cpp/uncertainty_quantification/uq_main.cc | 32 ++++++++++++++++------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 295fd88b0..2fb18a60f 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -7,6 +7,9 @@ #include "yaml-cpp/yaml.h" #include "purify/yaml-parser.h" #include "purify/setup_utils.h" +#include +#include +#include using VectorC = sopt::Vector>; @@ -38,9 +41,14 @@ int main(int argc, char **argv) const uint imsize_y = reference_image.rows(); // Prepare operators and data using either purify config - // If no purify config use basic version for now + // If no purify config use basic version for now based on algo_factory test images purify::utilities::vis_params measurement_data; - std::shared_ptr>> measurement_operator; + std::shared_ptr> measurement_operator; + std::shared_ptr> wavelet_operator; + std::vector> const sara{ + std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), + std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), + std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)}; if(UQ_config["purify_config_file"]) { YamlParser purify_config = YamlParser(UQ_config["purify_config_file"].as()); @@ -67,18 +75,19 @@ int main(int argc, char **argv) measurement_data = uv_data; measurement_operator = transform; + wavelet_operator = wavelets.transform; } else { const std::string measurements_path = UQ_config["measurements_path"].as(); // Load the images and measurements - const purify::utilities::vis_params uv_data = purify::utilities::read_visibility(measurements_path, false); + measurement_data = purify::utilities::read_visibility(measurements_path, false); // This is the measurement operator used in the test but this should probably be selectable - auto const transform = + measurement_operator = purify::factory::measurement_operator_factory>>( purify::factory::distributed_measurement_operator::serial, - uv_data, + measurement_data, imsize_y, imsize_x, 1, @@ -88,8 +97,8 @@ int main(int argc, char **argv) 4, 4); - measurement_operator = transform; - measurement_data = uv_data; + wavelet_operator = purify::factory::wavelet_operator_factory>( + factory::distributed_wavelet_operator::serial, sara, imsize_y, imsize_x); } // Set up confidence and objective function params @@ -137,11 +146,14 @@ int main(int argc, char **argv) // Calculate the posterior function for the reference image // posterior = likelihood + prior // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm) - // Prior = Sum(|x_i|) * gamma (L1 norm) - auto Posterior = [&measurement_data, measurement_operator, sigma, gamma](const VectorC &image) { + // Prior = Sum(Psi^t * |x_i|) * gamma (L1 norm) + auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, sigma, gamma](const VectorC &image) { { const auto residuals = (*measurement_operator * image) - measurement_data.vis; - return residuals.squaredNorm() / (2 * sigma * sigma) + image.cwiseAbs().sum() * gamma; + auto likelihood = residuals.squaredNorm() / (2 * sigma * sigma); + const VectorC wavelet_rep = wavelet_operator->adjoint() * image; + auto prior = wavelet_rep.cwiseAbs().sum() * gamma; + return likelihood + prior; } }; From da5ef51e0f74bfe4d1d0c13a9d232d6870b30172 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Mon, 2 Oct 2023 18:33:38 +0100 Subject: [PATCH 15/32] Update readme to include UQ --- README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/README.md b/README.md index 31103b020..2a8ddde66 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,32 @@ When `purify` runs a directory will be created, and the output images will be saved and time-stamped. Additionally, a config file with the settings used will be saved and time-stamped, helping for reproducibility and book-keeping. +Uncertainty Quantification +-------------------------- + +Bayesian hypothesis testing may be performed with the `purify_UQ` application, which can be located in the `build` directory. + +This application takes a config yaml file with the following parameters: +- `confidence_interval` or `alpha`. (alpha = 1 - confidence interval.)) +- `measurements_path`: path to measurements data (.vis file) +- `reference_image_path`: path to reference image i.e. output from purify run. (.fits files) +- `surrogate_image_path`: path to surrogate image i.e. doctored image with blurring or structural change that you want to test. (.fits file) +- `sigma`: standard deviation for Gaussian likelihood. +- `gamma`: multiplicative factor for prior. +- `purify_config`: path to purify config used to generate the reference image. **This should be used if available in order to ensure consistency of things like measurement and wavelet operators.** + +You can then run the uncertainty quantification with the command: +``` +purify_UQ +``` + +The application will report the value of the objective function for each image, the threshold value calculated from the reference image, and whether the surrogate image is ruled out or not. + +Presently this is designed to work for the unconstrained problem where: +- The negative log-likelihood is a scaled L2-norm i.e. $ \frac{1}{2 \sigma^2} \sum (y_i - \Phi x_i)$ for some data $y$, image $x$, and measurement operator $\Phi$. (Equivalent to indepdendent multivariate Gaussian likelihood.) +- The negative log prior is a scaled L1-norm in _wavelet space_ i.e. $\gamma \sum (\Psi^\dag x)_i$ for some image $x$ and wavelet operator $\Psi$. +- The objective function is the sum of these two terms. + Docker ------- From 693681a88e612cccf6348ef38150d02d2fb8397e Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 7 Nov 2024 10:40:01 +0000 Subject: [PATCH 16/32] Add generic cost function --- cpp/uncertainty_quantification/uq_main.cc | 31 ++++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 2fb18a60f..439737b2c 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -2,7 +2,10 @@ #include "purify/pfitsio.h" #include "purify/utilities.h" #include "purify/measurement_operator_factory.h" +#include "purify/setup_utils.h" #include "sopt/objective_functions.h" +#include "sopt/differentiable_func.h" +#include "sopt/non_differentiable_func.h" #include #include "yaml-cpp/yaml.h" #include "purify/yaml-parser.h" @@ -15,10 +18,13 @@ using VectorC = sopt::Vector>; int main(int argc, char **argv) { - if(argc != 2) + if(argc != 4) { - std::cout << "purify_UQ should be run using a single additional argument, which is the path to the config (yaml) file." << std::endl; - std::cout << "purify_UQ " << std::endl; + std::cout << "purify_UQ should be run using three additional arguments." << std::endl; + std::cout << "purify_UQ " << std::endl; + std::cout << ": path to a .yaml config file specifying details of measurement operator, wavelet operator, observations, and cost functions." << std::endl; + std::cout << ": path to image file (.fits) which was output from running purify on observed data." << std::endl; + std::cout << ": path to modified image file (.fits) for feature analysis." << std::endl; std::cout << std::endl; std::cout << "For more information about the contents of the config file please consult the README." << std::endl; return 1; @@ -29,8 +35,8 @@ int main(int argc, char **argv) const YAML::Node UQ_config = YAML::LoadFile(config_path); // Load the Reference and Surrogate images - const std::string ref_image_path = UQ_config["reference_image_path"].as(); - const std::string surrogate_image_path = UQ_config["surrogate_image_path"].as(); + const std::string ref_image_path = argv[2]; + const std::string surrogate_image_path = argv[3]; const auto reference_image = purify::pfitsio::read2d(ref_image_path); const VectorC reference_vector = VectorC::Map(reference_image.data(), reference_image.size()); const auto surrogate_image = purify::pfitsio::read2d(surrogate_image_path); @@ -85,7 +91,7 @@ int main(int argc, char **argv) // This is the measurement operator used in the test but this should probably be selectable measurement_operator = - purify::factory::measurement_operator_factory>>( + purify::factory::measurement_operator_factory>( purify::factory::distributed_measurement_operator::serial, measurement_data, imsize_y, @@ -141,19 +147,20 @@ int main(int argc, char **argv) return 3; } - + std::unique_ptr> f; + std::unique_ptr> g; + // set up f and g from config // Calculate the posterior function for the reference image // posterior = likelihood + prior // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm) // Prior = Sum(Psi^t * |x_i|) * gamma (L1 norm) - auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, sigma, gamma](const VectorC &image) { + auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, sigma, gamma, &f, &g](const VectorC &image) { { const auto residuals = (*measurement_operator * image) - measurement_data.vis; - auto likelihood = residuals.squaredNorm() / (2 * sigma * sigma); - const VectorC wavelet_rep = wavelet_operator->adjoint() * image; - auto prior = wavelet_rep.cwiseAbs().sum() * gamma; - return likelihood + prior; + auto A = f->function(image, measurement_data.vis, (*measurement_operator)); + auto B = g->function(image); + return A + gamma * B; } }; From 9f0923b26adc0dfab76102582e28584eb811b80d Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Wed, 13 Nov 2024 15:01:32 +0000 Subject: [PATCH 17/32] Remove some confusing root-2s --- cpp/purify/algorithm_factory.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index a279ebf6f..e0e0f3e65 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -174,8 +174,8 @@ fb_factory(const algo_distribution dist, auto fb = std::make_shared(uv_data.vis); fb->itermax(max_iterations) .gamma(reg_parameter) - .sigma(sigma * std::sqrt(2)) - .beta(step_size * std::sqrt(2)) + .sigma(sigma) + .beta(step_size) .relative_variation(relative_variation) .tight_frame(tight_frame) .nu(op_norm * op_norm) From f1a1982d9ce90a2627842e3a7cce71a7ea941fce Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 14 Nov 2024 16:36:50 +0000 Subject: [PATCH 18/32] Move non-templated functions to cpp file! --- cpp/purify/pfitsio.cc | 26 ++++++++++++++++++++++++++ cpp/purify/pfitsio.h | 24 ++---------------------- cpp/purify/setup_utils.h | 5 +++++ 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/cpp/purify/pfitsio.cc b/cpp/purify/pfitsio.cc index 857b2a77d..5a5733e2b 100644 --- a/cpp/purify/pfitsio.cc +++ b/cpp/purify/pfitsio.cc @@ -91,4 +91,30 @@ void write3d(const std::vector> &eigen_images, const std::string & write3d(eigen_images, header, overwrite); } +//! Read cube from fits file +std::vector> read3d(const std::string &fits_name) { + std::vector> eigen_images; + Vector image; + int rows, cols, channels, pols = 1; + read3d>(fits_name, image, rows, cols, channels, pols); + for (int i = 0; i < channels; i++) { + Vector eigen_image = Vector::Zero(rows * cols); + eigen_image.real() = image.segment(i * rows * cols, rows * cols); + eigen_images.push_back(Image::Map(eigen_image.data(), rows, cols)); + } + return eigen_images; +} + +//! Read image from fits file +Image read2d(const std::string &fits_name) { + /* + Reads in an image from a fits file and returns the image. + + fits_name:: name of fits file + */ + + const std::vector> images = read3d(fits_name); + return images.at(0); +} + } // namespace purify::pfitsio diff --git a/cpp/purify/pfitsio.h b/cpp/purify/pfitsio.h index 4312417c4..6cbc2fa02 100644 --- a/cpp/purify/pfitsio.h +++ b/cpp/purify/pfitsio.h @@ -322,30 +322,10 @@ void read3d(const std::string &fits_name, Eigen::EigenBase &output, int &rows } //! Read cube from fits file -std::vector> read3d(const std::string &fits_name) { - std::vector> eigen_images; - Vector image; - int rows, cols, channels, pols = 1; - read3d>(fits_name, image, rows, cols, channels, pols); - for (int i = 0; i < channels; i++) { - Vector eigen_image = Vector::Zero(rows * cols); - eigen_image.real() = image.segment(i * rows * cols, rows * cols); - eigen_images.push_back(Image::Map(eigen_image.data(), rows, cols)); - } - return eigen_images; -} +std::vector> read3d(const std::string &fits_name); //! Read image from fits file -Image read2d(const std::string &fits_name) { - /* - Reads in an image from a fits file and returns the image. - - fits_name:: name of fits file - */ - - const std::vector> images = read3d(fits_name); - return images.at(0); -} +Image read2d(const std::string &fits_name); } // namespace purify::pfitsio diff --git a/cpp/purify/setup_utils.h b/cpp/purify/setup_utils.h index 1e597a9ac..e0f614ed2 100644 --- a/cpp/purify/setup_utils.h +++ b/cpp/purify/setup_utils.h @@ -1,3 +1,6 @@ +#ifndef SETUP_UTILS_H +#define SETUP_UTILS_H + #include "purify/types.h" #include "purify/measurement_operator_factory.h" #include "purify/wavelet_operator_factory.h" @@ -81,3 +84,5 @@ void saveDirtyImage(const YamlParser ¶ms, const std::shared_ptr>> &measurements_transform, const utilities::vis_params &uv_data, const t_real beam_units); + +#endif \ No newline at end of file From 48d3dd84bbcecbc95301f3c092814f939018ea79 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 14 Nov 2024 16:37:15 +0000 Subject: [PATCH 19/32] Spacing --- cpp/uncertainty_quantification/uq_main.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 439737b2c..acdf884eb 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -42,7 +42,6 @@ int main(int argc, char **argv) const auto surrogate_image = purify::pfitsio::read2d(surrogate_image_path); const VectorC surrogate_vector = VectorC::Map(surrogate_image.data(), surrogate_image.size()); - const uint imsize_x = reference_image.cols(); const uint imsize_y = reference_image.rows(); From 95703117e0b168eab0e573888bdded5471f52536 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 14:08:49 +0000 Subject: [PATCH 20/32] Use non-greek sopt interface --- cpp/purify/algorithm_factory.h | 20 ++++++++++---------- cpp/purify/update_factory.h | 16 ++++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index a279ebf6f..805d8f242 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -80,7 +80,7 @@ padmm_factory(const algo_distribution dist, .l1_proximal_positivity_constraint(positive_constraint) .l1_proximal_real_constraint(real_constraint) .lagrange_update_scale(0.9) - .nu(op_norm * op_norm) + .sq_op_norm(op_norm * op_norm) .Psi(*wavelets) .Phi(*measurements); #ifdef PURIFY_MPI @@ -90,7 +90,7 @@ padmm_factory(const algo_distribution dist, switch (dist) { case (algo_distribution::serial): padmm - ->gamma((wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval()) + ->regulariser_strength((wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval()) .cwiseAbs() .maxCoeff() * 1e-3) @@ -132,8 +132,8 @@ padmm_factory(const algo_distribution dist, std::weak_ptr const padmm_weak(padmm); // set epsilon padmm->residual_tolerance(epsilon * residual_tolerance_scaling).l2ball_proximal_epsilon(epsilon); - // set gamma - padmm->gamma(comm.all_reduce( + // set regulariser_strength + padmm->regulariser_strength(comm.all_reduce( utilities::step_size>(uv_data.vis, measurements, wavelets, sara_size) * 1e-3, MPI_MAX)); @@ -173,12 +173,12 @@ fb_factory(const algo_distribution dist, "one wavelet basis."); auto fb = std::make_shared(uv_data.vis); fb->itermax(max_iterations) - .gamma(reg_parameter) + .regulariser_strength(reg_parameter) .sigma(sigma * std::sqrt(2)) - .beta(step_size * std::sqrt(2)) + .step_size(step_size * std::sqrt(2)) .relative_variation(relative_variation) .tight_frame(tight_frame) - .nu(op_norm * op_norm) + .sq_op_norm(op_norm * op_norm) .Phi(*measurements); if (f_function) fb->f_function(f_function); // only override f_function default if non-null @@ -282,7 +282,7 @@ primaldual_factory( switch (dist) { case (algo_distribution::serial): { primaldual - ->gamma((wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval()) + ->regulariser_strength((wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval()) .cwiseAbs() .maxCoeff() * 1e-3) @@ -345,8 +345,8 @@ primaldual_factory( // set epsilon primaldual->residual_tolerance(epsilon * residual_tolerance_scaling) .l2ball_proximal_epsilon(epsilon); - // set gamma - primaldual->gamma(comm.all_reduce( + // set regulariser_strength + primaldual->regulariser_strength(comm.all_reduce( utilities::step_size>(uv_data.vis, measurements, wavelets, sara_size) * 1e-3, MPI_MAX)); diff --git a/cpp/purify/update_factory.h b/cpp/purify/update_factory.h index 3c4ec6b5f..6c4bc2fab 100644 --- a/cpp/purify/update_factory.h +++ b/cpp/purify/update_factory.h @@ -43,17 +43,17 @@ void add_updater(std::weak_ptr const algo_weak, const t_real step_size_sca step_size_scale, update_header_sol, update_header_res, sara_size, comm, beam_units](const Vector &x, const Vector &res) -> bool { auto algo = algo_weak.lock(); - if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ {}", algo->gamma()); - if (algo->gamma() > 0) { + if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ {}", algo->regulariser_strength()); + if (algo->regulariser_strength() > 0) { Vector const alpha = algo->Psi().adjoint() * x; const t_real new_gamma = comm.all_reduce((sara_size > 0) ? alpha.real().cwiseAbs().maxCoeff() : 0., MPI_MAX) * step_size_scale; if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma); // updating parameter - algo->gamma(((std::abs(algo->gamma() - new_gamma) > update_tol) and *iter < update_iters) + algo->regulariser_strength(((std::abs(algo->regulariser_strength() - new_gamma) > update_tol) and *iter < update_iters) ? new_gamma - : algo->gamma()); + : algo->regulariser_strength()); } Vector const residual = algo->Phi().adjoint() * (res / beam_units).eval(); PURIFY_MEDIUM_LOG("RMS of residual map in Jy/beam {}", @@ -86,16 +86,16 @@ void add_updater(std::weak_ptr const algo_weak, const t_real step_size_sca #endif ](const Vector &x, const Vector &res) -> bool { auto algo = algo_weak.lock(); - if (algo->gamma() > 0) { - PURIFY_MEDIUM_LOG("Step size γ {}", algo->gamma()); + if (algo->regulariser_strength() > 0) { + PURIFY_MEDIUM_LOG("Step size γ {}", algo->regulariser_strength()); Vector const alpha = algo->Psi().adjoint() * x; const t_real new_gamma = alpha.real().cwiseAbs().maxCoeff() * step_size_scale; PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma); // updating parameter - algo->gamma(((std::abs((algo->gamma() - new_gamma) / algo->gamma()) > update_tol) and + algo->regulariser_strength(((std::abs((algo->regulariser_strength() - new_gamma) / algo->regulariser_strength()) > update_tol) and *iter < update_iters) ? new_gamma - : algo->gamma()); + : algo->regulariser_strength()); } Vector const residual = algo->Phi().adjoint() * (res / beam_units).eval(); PURIFY_MEDIUM_LOG("RMS of residual map in Jy/beam {}", From 7ec30b4640efbbb213eae6ef39bc938b0fe18ab1 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 14:13:03 +0000 Subject: [PATCH 21/32] Lintin' --- cpp/purify/algorithm_factory.h | 18 ++++++++++-------- cpp/purify/update_factory.h | 17 ++++++++++------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index 805d8f242..49177849f 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -90,10 +90,11 @@ padmm_factory(const algo_distribution dist, switch (dist) { case (algo_distribution::serial): padmm - ->regulariser_strength((wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval()) - .cwiseAbs() - .maxCoeff() * - 1e-3) + ->regulariser_strength( + (wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval()) + .cwiseAbs() + .maxCoeff() * + 1e-3) .l2ball_proximal_epsilon(epsilon) .residual_tolerance(epsilon * residual_tolerance_scaling); return padmm; @@ -282,10 +283,11 @@ primaldual_factory( switch (dist) { case (algo_distribution::serial): { primaldual - ->regulariser_strength((wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval()) - .cwiseAbs() - .maxCoeff() * - 1e-3) + ->regulariser_strength( + (wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval()) + .cwiseAbs() + .maxCoeff() * + 1e-3) .l2ball_proximal_epsilon(epsilon) .residual_tolerance(epsilon * residual_tolerance_scaling); return primaldual; diff --git a/cpp/purify/update_factory.h b/cpp/purify/update_factory.h index 6c4bc2fab..0983d667f 100644 --- a/cpp/purify/update_factory.h +++ b/cpp/purify/update_factory.h @@ -51,9 +51,11 @@ void add_updater(std::weak_ptr const algo_weak, const t_real step_size_sca step_size_scale; if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma); // updating parameter - algo->regulariser_strength(((std::abs(algo->regulariser_strength() - new_gamma) > update_tol) and *iter < update_iters) - ? new_gamma - : algo->regulariser_strength()); + algo->regulariser_strength( + ((std::abs(algo->regulariser_strength() - new_gamma) > update_tol) and + *iter < update_iters) + ? new_gamma + : algo->regulariser_strength()); } Vector const residual = algo->Phi().adjoint() * (res / beam_units).eval(); PURIFY_MEDIUM_LOG("RMS of residual map in Jy/beam {}", @@ -92,10 +94,11 @@ void add_updater(std::weak_ptr const algo_weak, const t_real step_size_sca const t_real new_gamma = alpha.real().cwiseAbs().maxCoeff() * step_size_scale; PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma); // updating parameter - algo->regulariser_strength(((std::abs((algo->regulariser_strength() - new_gamma) / algo->regulariser_strength()) > update_tol) and - *iter < update_iters) - ? new_gamma - : algo->regulariser_strength()); + algo->regulariser_strength(((std::abs((algo->regulariser_strength() - new_gamma) / + algo->regulariser_strength()) > update_tol) and + *iter < update_iters) + ? new_gamma + : algo->regulariser_strength()); } Vector const residual = algo->Phi().adjoint() * (res / beam_units).eval(); PURIFY_MEDIUM_LOG("RMS of residual map in Jy/beam {}", From 1d350842763a9af856209af3b3f453911e4a844a Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 15:23:31 +0000 Subject: [PATCH 22/32] Unified approach for f/g function types --- cpp/purify/algorithm_factory.h | 13 ++++++------- cpp/purify/types.h | 6 ++++++ cpp/purify/yaml-parser.cc | 1 + cpp/purify/yaml-parser.h | 2 +- cpp/tests/algo_factory.cc | 4 ++-- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index 12c3c7fd5..360f5b6f0 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -35,14 +35,13 @@ namespace purify { namespace factory { enum class algorithm { padmm, primal_dual, sdmm, forward_backward }; enum class algo_distribution { serial, mpi_serial, mpi_distributed, mpi_random_updates }; -enum class g_proximal_type { L1GProximal, TFGProximal, Indicator }; const std::map algo_distribution_string = { {"none", algo_distribution::serial}, {"serial-equivalent", algo_distribution::mpi_serial}, {"random-updates", algo_distribution::mpi_random_updates}, {"fully-distributed", algo_distribution::mpi_distributed}}; -const std::map g_proximal_type_string = { - {"l1", g_proximal_type::L1GProximal}, {"learned", g_proximal_type::TFGProximal}}; +const std::map g_proximal_type_string = { + {"l1", nondiff_func_type::L1Norm}, {"denoiser", nondiff_func_type::Denoiser}, {"realIndicator", nondiff_func_type::RealIndicator}}; //! return chosen algorithm given parameters template @@ -165,7 +164,7 @@ fb_factory(const algo_distribution dist, const bool tight_frame = false, const t_real relative_variation = 1e-3, const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50, const t_real op_norm = 1, const std::string model_path = "", - const g_proximal_type g_proximal = g_proximal_type::L1GProximal, + const nondiff_func_type g_proximal = nondiff_func_type::L1Norm, std::shared_ptr> f_function = nullptr) { typedef typename Algorithm::Scalar t_scalar; if (sara_size > 1 and tight_frame) @@ -186,7 +185,7 @@ fb_factory(const algo_distribution dist, std::shared_ptr> g; switch (g_proximal) { - case (g_proximal_type::L1GProximal): { + case (nondiff_func_type::L1Norm): { // Create a shared pointer to an instance of the L1GProximal class // and set its properties auto l1_gp = std::make_shared>(false); @@ -206,7 +205,7 @@ fb_factory(const algo_distribution dist, g = l1_gp; break; } - case (g_proximal_type::TFGProximal): { + case (nondiff_func_type::Denoiser): { #ifdef PURIFY_ONNXRT // Create a shared pointer to an instance of the TFGProximal class g = std::make_shared>(model_path); @@ -216,7 +215,7 @@ fb_factory(const algo_distribution dist, "Type TFGProximal not recognized because purify was built with onnxrt=off"); #endif } - case (g_proximal_type::Indicator): { + case (nondiff_func_type::RealIndicator): { g = std::make_shared>(); break; } diff --git a/cpp/purify/types.h b/cpp/purify/types.h index 7754f983b..065744978 100644 --- a/cpp/purify/types.h +++ b/cpp/purify/types.h @@ -22,6 +22,12 @@ typedef std::complex t_complexf; //! Root of the type hierarchy for triplet lists typedef Eigen::Triplet t_tripletList; +// Different available types of differentiable functions (f) +enum class diff_func_type {L2Norm, L2Norm_with_CRR}; + +// Different available types of non-differentiable functions (g) +enum class nondiff_func_type {L1Norm, Denoiser, RealIndicator}; + //! \brief A matrix of a given type //! \details Operates as mathematical sparse matrix. template diff --git a/cpp/purify/yaml-parser.cc b/cpp/purify/yaml-parser.cc index 0a3a11f7b..bea5cdd7c 100644 --- a/cpp/purify/yaml-parser.cc +++ b/cpp/purify/yaml-parser.cc @@ -224,6 +224,7 @@ void YamlParser::parseAndSetAlgorithmOptions(const YAML::Node& algorithmOptionsN get(algorithmOptionsNode, {"fb", "dualFBVarianceConvergence"}); this->gProximalType_ = factory::g_proximal_type_string.at( get(algorithmOptionsNode, {"fb", "gProximalType"})); + this->model_path_ = get(algorithmOptionsNode, {"fb", "modelPath"}); if (this->algorithm_ == "fb_joint_map") { this->jmap_iters_ = diff --git a/cpp/purify/yaml-parser.h b/cpp/purify/yaml-parser.h index d5e2b40bf..533700ba7 100644 --- a/cpp/purify/yaml-parser.h +++ b/cpp/purify/yaml-parser.h @@ -141,7 +141,7 @@ class YamlParser { YAML_MACRO(t_real, jmap_beta, 1) YAML_MACRO(std::string, model_path, "") - YAML_MACRO(factory::g_proximal_type, gProximalType, factory::g_proximal_type::L1GProximal) + YAML_MACRO(nondiff_func_type, gProximalType, nondiff_func_type::L1Norm) std::string output_path() const { diff --git a/cpp/tests/algo_factory.cc b/cpp/tests/algo_factory.cc index 1e7b1ac48..5c8b0dd55 100644 --- a/cpp/tests/algo_factory.cc +++ b/cpp/tests/algo_factory.cc @@ -223,7 +223,7 @@ TEST_CASE("tf_fb_factory") { auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, - tf_model_path, factory::g_proximal_type::TFGProximal); + tf_model_path, nondiff_func_type::Denoiser); auto const diagnostic = (*fb)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); @@ -285,7 +285,7 @@ TEST_CASE("onnx_fb_factory") { auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, "", - factory::g_proximal_type::Indicator, diff_function); + nondiff_func_type::RealIndicator, diff_function); auto const diagnostic = (*fb)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); From 734abcfaf320a49fcb23f6916823b7c964092ecf Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 15:37:56 +0000 Subject: [PATCH 23/32] Remove gProximalType references --- cpp/main.cc | 2 +- cpp/purify/algorithm_factory.h | 2 +- cpp/purify/yaml-parser.cc | 3 +-- cpp/purify/yaml-parser.h | 3 ++- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/main.cc b/cpp/main.cc index 596e5d472..7a6314b27 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -127,7 +127,7 @@ int main(int argc, const char **argv) { (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and (not params.positiveValueConstraint()), params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, operator_norm, - params.model_path(), params.gProximalType()); + params.model_path(), params.nondiffFuncType()); if (params.algorithm() == "primaldual") primaldual = factory::primaldual_factory>( params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index 360f5b6f0..d11993d06 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -40,7 +40,7 @@ const std::map algo_distribution_string = { {"serial-equivalent", algo_distribution::mpi_serial}, {"random-updates", algo_distribution::mpi_random_updates}, {"fully-distributed", algo_distribution::mpi_distributed}}; -const std::map g_proximal_type_string = { +const std::map nondiff_type_string = { {"l1", nondiff_func_type::L1Norm}, {"denoiser", nondiff_func_type::Denoiser}, {"realIndicator", nondiff_func_type::RealIndicator}}; //! return chosen algorithm given parameters diff --git a/cpp/purify/yaml-parser.cc b/cpp/purify/yaml-parser.cc index bea5cdd7c..3a6cab1cd 100644 --- a/cpp/purify/yaml-parser.cc +++ b/cpp/purify/yaml-parser.cc @@ -222,9 +222,8 @@ void YamlParser::parseAndSetAlgorithmOptions(const YAML::Node& algorithmOptionsN get(algorithmOptionsNode, {"fb", "regularisation_parameter"}); this->dualFBVarianceConvergence_ = get(algorithmOptionsNode, {"fb", "dualFBVarianceConvergence"}); - this->gProximalType_ = factory::g_proximal_type_string.at( + this->nondiffFuncType_ = factory::nondiff_type_string.at( get(algorithmOptionsNode, {"fb", "gProximalType"})); - this->model_path_ = get(algorithmOptionsNode, {"fb", "modelPath"}); if (this->algorithm_ == "fb_joint_map") { this->jmap_iters_ = diff --git a/cpp/purify/yaml-parser.h b/cpp/purify/yaml-parser.h index 533700ba7..4b958b1cd 100644 --- a/cpp/purify/yaml-parser.h +++ b/cpp/purify/yaml-parser.h @@ -141,7 +141,8 @@ class YamlParser { YAML_MACRO(t_real, jmap_beta, 1) YAML_MACRO(std::string, model_path, "") - YAML_MACRO(nondiff_func_type, gProximalType, nondiff_func_type::L1Norm) + YAML_MACRO(nondiff_func_type, nondiffFuncType, nondiff_func_type::L1Norm) + YAML_MACRO(diff_func_type, diffFuncType, diff_func_type::L2Norm) std::string output_path() const { From 773bd912ccc300b206af1df2efca44e92ff65dae Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 15:39:54 +0000 Subject: [PATCH 24/32] Move type map into types header --- cpp/purify/algorithm_factory.h | 2 -- cpp/purify/types.h | 2 ++ cpp/purify/yaml-parser.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index d11993d06..3a6876590 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -40,8 +40,6 @@ const std::map algo_distribution_string = { {"serial-equivalent", algo_distribution::mpi_serial}, {"random-updates", algo_distribution::mpi_random_updates}, {"fully-distributed", algo_distribution::mpi_distributed}}; -const std::map nondiff_type_string = { - {"l1", nondiff_func_type::L1Norm}, {"denoiser", nondiff_func_type::Denoiser}, {"realIndicator", nondiff_func_type::RealIndicator}}; //! return chosen algorithm given parameters template diff --git a/cpp/purify/types.h b/cpp/purify/types.h index 065744978..56ab90c0b 100644 --- a/cpp/purify/types.h +++ b/cpp/purify/types.h @@ -27,6 +27,8 @@ enum class diff_func_type {L2Norm, L2Norm_with_CRR}; // Different available types of non-differentiable functions (g) enum class nondiff_func_type {L1Norm, Denoiser, RealIndicator}; +const std::map nondiff_type_string = { + {"l1", nondiff_func_type::L1Norm}, {"denoiser", nondiff_func_type::Denoiser}, {"realIndicator", nondiff_func_type::RealIndicator}}; //! \brief A matrix of a given type //! \details Operates as mathematical sparse matrix. diff --git a/cpp/purify/yaml-parser.cc b/cpp/purify/yaml-parser.cc index 3a6cab1cd..118d9499a 100644 --- a/cpp/purify/yaml-parser.cc +++ b/cpp/purify/yaml-parser.cc @@ -222,7 +222,7 @@ void YamlParser::parseAndSetAlgorithmOptions(const YAML::Node& algorithmOptionsN get(algorithmOptionsNode, {"fb", "regularisation_parameter"}); this->dualFBVarianceConvergence_ = get(algorithmOptionsNode, {"fb", "dualFBVarianceConvergence"}); - this->nondiffFuncType_ = factory::nondiff_type_string.at( + this->nondiffFuncType_ = nondiff_type_string.at( get(algorithmOptionsNode, {"fb", "gProximalType"})); this->model_path_ = get(algorithmOptionsNode, {"fb", "modelPath"}); if (this->algorithm_ == "fb_joint_map") { From 632ecb98c8204d3a0628a2753962b90e3cdc83b3 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 15:49:15 +0000 Subject: [PATCH 25/32] Add differentiable function to yaml --- cpp/purify/types.h | 6 +++++- cpp/purify/yaml-parser.cc | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/cpp/purify/types.h b/cpp/purify/types.h index 56ab90c0b..e44814c63 100644 --- a/cpp/purify/types.h +++ b/cpp/purify/types.h @@ -24,11 +24,15 @@ typedef Eigen::Triplet t_tripletList; // Different available types of differentiable functions (f) enum class diff_func_type {L2Norm, L2Norm_with_CRR}; +const std::map diff_type_string = { + {"l2", diff_func_type::L2Norm}, {"CRR", diff_func_type::L2Norm_with_CRR} +}; // Different available types of non-differentiable functions (g) enum class nondiff_func_type {L1Norm, Denoiser, RealIndicator}; const std::map nondiff_type_string = { - {"l1", nondiff_func_type::L1Norm}, {"denoiser", nondiff_func_type::Denoiser}, {"realIndicator", nondiff_func_type::RealIndicator}}; + {"l1", nondiff_func_type::L1Norm}, {"denoiser", nondiff_func_type::Denoiser}, {"realIndicator", nondiff_func_type::RealIndicator} +}; //! \brief A matrix of a given type //! \details Operates as mathematical sparse matrix. diff --git a/cpp/purify/yaml-parser.cc b/cpp/purify/yaml-parser.cc index 118d9499a..05e5dc7d1 100644 --- a/cpp/purify/yaml-parser.cc +++ b/cpp/purify/yaml-parser.cc @@ -223,7 +223,9 @@ void YamlParser::parseAndSetAlgorithmOptions(const YAML::Node& algorithmOptionsN this->dualFBVarianceConvergence_ = get(algorithmOptionsNode, {"fb", "dualFBVarianceConvergence"}); this->nondiffFuncType_ = nondiff_type_string.at( - get(algorithmOptionsNode, {"fb", "gProximalType"})); + get(algorithmOptionsNode, {"fb", "nonDifferentiableFunctionType"})); + this->diffFuncType_ = diff_type_string.at( + get(algorithmOptionsNode, {"fb", "differentiableFunctionType"})); this->model_path_ = get(algorithmOptionsNode, {"fb", "modelPath"}); if (this->algorithm_ == "fb_joint_map") { this->jmap_iters_ = From 2dd9c4c8ef6df601e84922056e77b3bd272894bb Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 16:07:09 +0000 Subject: [PATCH 26/32] Optional yaml parsing --- cpp/purify/yaml-parser.cc | 15 ++++++++++++++- cpp/purify/yaml-parser.h | 4 ++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/cpp/purify/yaml-parser.cc b/cpp/purify/yaml-parser.cc index 05e5dc7d1..716055108 100644 --- a/cpp/purify/yaml-parser.cc +++ b/cpp/purify/yaml-parser.cc @@ -222,11 +222,24 @@ void YamlParser::parseAndSetAlgorithmOptions(const YAML::Node& algorithmOptionsN get(algorithmOptionsNode, {"fb", "regularisation_parameter"}); this->dualFBVarianceConvergence_ = get(algorithmOptionsNode, {"fb", "dualFBVarianceConvergence"}); + this->nondiffFuncType_ = nondiff_type_string.at( get(algorithmOptionsNode, {"fb", "nonDifferentiableFunctionType"})); + if(this->nondiffFuncType_ == nondiff_func_type::Denoiser) + { + this->model_path_ = get(algorithmOptionsNode, {"fb", "modelPath"}); + } + this->diffFuncType_ = diff_type_string.at( get(algorithmOptionsNode, {"fb", "differentiableFunctionType"})); - this->model_path_ = get(algorithmOptionsNode, {"fb", "modelPath"}); + if(this->diffFuncType_ == diff_func_type::L2Norm_with_CRR) + { + this->CRR_function_model_path_ = get(algorithmOptionsNode, {"fb", "CRR_function_model_path"}); + this->CRR_gradient_model_path_ = get(algorithmOptionsNode, {"fb", "CRR_gradient_model_path"}); + this->CRR_mu = get(algorithmOptionsNode, {"fb", "CRR_mu"}); + this->CRR_lambda = get(algorithmOptionsNode, {"fb", "CRR_lambda"}); + } + if (this->algorithm_ == "fb_joint_map") { this->jmap_iters_ = get(algorithmOptionsNode, {"fb", "joint_map_estimation", "iters"}); diff --git a/cpp/purify/yaml-parser.h b/cpp/purify/yaml-parser.h index 4b958b1cd..daa3d5f2a 100644 --- a/cpp/purify/yaml-parser.h +++ b/cpp/purify/yaml-parser.h @@ -143,6 +143,10 @@ class YamlParser { YAML_MACRO(std::string, model_path, "") YAML_MACRO(nondiff_func_type, nondiffFuncType, nondiff_func_type::L1Norm) YAML_MACRO(diff_func_type, diffFuncType, diff_func_type::L2Norm) + YAML_MACRO(std::string, CRR_function_model_path, "") + YAML_MACRO(std::string, CRR_gradient_model_path, "") + YAML_MACRO(t_real, CRR_mu, 20) + YAML_MACRO(t_real, CRR_lambda, 5000) std::string output_path() const { From 0c11af24159aeeb1eedeaadd7d5f10802fcbc7a1 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 16:08:56 +0000 Subject: [PATCH 27/32] syntax errors --- cpp/purify/yaml-parser.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/purify/yaml-parser.cc b/cpp/purify/yaml-parser.cc index 716055108..34fe67dc4 100644 --- a/cpp/purify/yaml-parser.cc +++ b/cpp/purify/yaml-parser.cc @@ -236,8 +236,8 @@ void YamlParser::parseAndSetAlgorithmOptions(const YAML::Node& algorithmOptionsN { this->CRR_function_model_path_ = get(algorithmOptionsNode, {"fb", "CRR_function_model_path"}); this->CRR_gradient_model_path_ = get(algorithmOptionsNode, {"fb", "CRR_gradient_model_path"}); - this->CRR_mu = get(algorithmOptionsNode, {"fb", "CRR_mu"}); - this->CRR_lambda = get(algorithmOptionsNode, {"fb", "CRR_lambda"}); + this->CRR_mu_ = get(algorithmOptionsNode, {"fb", "CRR_mu"}); + this->CRR_lambda_ = get(algorithmOptionsNode, {"fb", "CRR_lambda"}); } if (this->algorithm_ == "fb_joint_map") { From 277b8cadcae8e2e0388ff03add2fce4eaa9b7414 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 17:56:54 +0000 Subject: [PATCH 28/32] Add construction of f and g --- cpp/purify/algorithm_factory.h | 2 +- cpp/uncertainty_quantification/uq_main.cc | 60 +++++++++++++++++++---- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index 3a6876590..965e20aeb 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -214,7 +214,7 @@ fb_factory(const algo_distribution dist, #endif } case (nondiff_func_type::RealIndicator): { - g = std::make_shared>(); + g = std::make_shared>(); break; } default: { diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index acdf884eb..2b26184e5 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -13,6 +13,12 @@ #include #include #include +#include +#include +#include +#include +#include + using VectorC = sopt::Vector>; @@ -45,7 +51,12 @@ int main(int argc, char **argv) const uint imsize_x = reference_image.cols(); const uint imsize_y = reference_image.rows(); - // Prepare operators and data using either purify config + std::unique_ptr> f; + std::unique_ptr> g; + + double sigma; + + // Prepare operators and data using purify config // If no purify config use basic version for now based on algo_factory test images purify::utilities::vis_params measurement_data; std::shared_ptr> measurement_operator; @@ -81,6 +92,39 @@ int main(int argc, char **argv) measurement_data = uv_data; measurement_operator = transform; wavelet_operator = wavelets.transform; + + // set up f and g from config + switch (purify_config.diffFuncType()) + { + case purify::diff_func_type::L2Norm: + f = std::make_unique>(sigma, *measurement_operator); + break; + case purify::diff_func_type::L2Norm_with_CRR: + f = std::make_unique>( + purify_config.CRR_function_model_path(), + purify_config.CRR_gradient_model_path(), + sigma, + purify_config.CRR_mu(), + purify_config.CRR_lambda(), + *measurement_operator + ); + break; + } + + switch (purify_config.nondiffFuncType()) + { + case purify::nondiff_func_type::L1Norm: + g = std::make_unique>(); + break; + case purify::nondiff_func_type::Denoiser: + g = std::make_unique>( + purify_config.model_path() + ); + break; + case purify::nondiff_func_type::RealIndicator: + g = std::make_unique>(); + break; + } } else { @@ -129,8 +173,8 @@ int main(int argc, char **argv) std::cout << "Config file must contain either 'confidence_interval' or 'alpha' as a parameter." << std::endl; return 1; } - const double sigma = UQ_config["sigma"].as(); - const double gamma = UQ_config["gamma"].as(); + + const double regulariser_strength = UQ_config["regulariser_strength"].as(); if((imsize_x != surrogate_image.cols()) || (imsize_y != surrogate_image.rows())) @@ -146,20 +190,16 @@ int main(int argc, char **argv) return 3; } - std::unique_ptr> f; - std::unique_ptr> g; - // set up f and g from config - // Calculate the posterior function for the reference image // posterior = likelihood + prior // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm) - // Prior = Sum(Psi^t * |x_i|) * gamma (L1 norm) - auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, sigma, gamma, &f, &g](const VectorC &image) { + // Prior = Sum(Psi^t * |x_i|) * regulariser_strength (L1 norm) + auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, sigma, regulariser_strength, &f, &g](const VectorC &image) { { const auto residuals = (*measurement_operator * image) - measurement_data.vis; auto A = f->function(image, measurement_data.vis, (*measurement_operator)); auto B = g->function(image); - return A + gamma * B; + return A + regulariser_strength * B; } }; From 6af78c785602464b308f901cc1f9c786be77461a Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 17:59:39 +0000 Subject: [PATCH 29/32] Fix sigma shadowing --- cpp/uncertainty_quantification/uq_main.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 2b26184e5..a09862c79 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -54,8 +54,6 @@ int main(int argc, char **argv) std::unique_ptr> f; std::unique_ptr> g; - double sigma; - // Prepare operators and data using purify config // If no purify config use basic version for now based on algo_factory test images purify::utilities::vis_params measurement_data; @@ -194,7 +192,7 @@ int main(int argc, char **argv) // posterior = likelihood + prior // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm) // Prior = Sum(Psi^t * |x_i|) * regulariser_strength (L1 norm) - auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, sigma, regulariser_strength, &f, &g](const VectorC &image) { + auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, regulariser_strength, &f, &g](const VectorC &image) { { const auto residuals = (*measurement_operator * image) - measurement_data.vis; auto A = f->function(image, measurement_data.vis, (*measurement_operator)); From 7bfd22b0daaca09dc26ac210dfdc619d708d0c2c Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 18:11:38 +0000 Subject: [PATCH 30/32] Refactor function setup for re-use --- cpp/purify/setup_utils.cc | 37 +++++++++++++++++++++++ cpp/purify/setup_utils.h | 6 ++++ cpp/uncertainty_quantification/uq_main.cc | 34 ++------------------- 3 files changed, 45 insertions(+), 32 deletions(-) diff --git a/cpp/purify/setup_utils.cc b/cpp/purify/setup_utils.cc index 0721075b6..32f88839c 100644 --- a/cpp/purify/setup_utils.cc +++ b/cpp/purify/setup_utils.cc @@ -1,5 +1,12 @@ #include "purify/setup_utils.h" #include +#include +#include +#include +#include +#include +#include +#include using namespace purify; @@ -299,6 +306,36 @@ measurementOpInfo createMeasurementOperator(const YamlParser ¶ms, return {measurements_transform, operator_norm}; } +void setupCostFunctions(const YamlParser ¶ms, + std::unique_ptr> &f, + std::unique_ptr> &g, + t_real sigma, + sopt::LinearTransform> &Phi) +{ + switch (params.diffFuncType()) { + case purify::diff_func_type::L2Norm: + f = std::make_unique>(sigma, Phi); + break; + case purify::diff_func_type::L2Norm_with_CRR: + f = std::make_unique>( + params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma, params.CRR_mu(), + params.CRR_lambda(), Phi); + break; + } + + switch (params.nondiffFuncType()) { + case purify::nondiff_func_type::L1Norm: + g = std::make_unique>(); + break; + case purify::nondiff_func_type::Denoiser: + g = std::make_unique>(params.model_path()); + break; + case purify::nondiff_func_type::RealIndicator: + g = std::make_unique>(); + break; + } +} + void initOutDirectoryWithConfig(YamlParser ¶ms) { if (params.mpiAlgorithm() != factory::algo_distribution::serial) { diff --git a/cpp/purify/setup_utils.h b/cpp/purify/setup_utils.h index e0f614ed2..d5ef9c0b3 100644 --- a/cpp/purify/setup_utils.h +++ b/cpp/purify/setup_utils.h @@ -57,6 +57,12 @@ measurementOpInfo createMeasurementOperator(const YamlParser ¶ms, const utilities::vis_params &uv_data, Vector &measurement_op_eigen_vector); +void setupCostFunctions(const YamlParser ¶ms, + std::unique_ptr> &f, + std::unique_ptr> &g, + t_real sigma, + sopt::LinearTransform> &Phi); + void initOutDirectoryWithConfig(YamlParser ¶ms); struct Headers diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index a09862c79..fd0fdb134 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -91,38 +91,8 @@ int main(int argc, char **argv) measurement_operator = transform; wavelet_operator = wavelets.transform; - // set up f and g from config - switch (purify_config.diffFuncType()) - { - case purify::diff_func_type::L2Norm: - f = std::make_unique>(sigma, *measurement_operator); - break; - case purify::diff_func_type::L2Norm_with_CRR: - f = std::make_unique>( - purify_config.CRR_function_model_path(), - purify_config.CRR_gradient_model_path(), - sigma, - purify_config.CRR_mu(), - purify_config.CRR_lambda(), - *measurement_operator - ); - break; - } - - switch (purify_config.nondiffFuncType()) - { - case purify::nondiff_func_type::L1Norm: - g = std::make_unique>(); - break; - case purify::nondiff_func_type::Denoiser: - g = std::make_unique>( - purify_config.model_path() - ); - break; - case purify::nondiff_func_type::RealIndicator: - g = std::make_unique>(); - break; - } + // setup f and g based on config file + setupCostFunctions(purify_config, f, g, sigma, *measurement_operator); } else { From f089c0738546c8dcb15a57cf71d9230d5c68af14 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 18:20:13 +0000 Subject: [PATCH 31/32] Default to avoid null pointers --- cpp/uncertainty_quantification/uq_main.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index fd0fdb134..d912c2b13 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -116,6 +116,10 @@ int main(int argc, char **argv) wavelet_operator = purify::factory::wavelet_operator_factory>( factory::distributed_wavelet_operator::serial, sara, imsize_y, imsize_x); + + // default cost function + f = std::make_unique>(1, *measurement_operator); // what would a default sigma look like?? + g = std::make_unique>(); } // Set up confidence and objective function params From 65cdc87606e70aeb5b511d96f01f8e208559c69d Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 21 Nov 2024 18:42:02 +0000 Subject: [PATCH 32/32] add onnxrt to main --- cpp/main.cc | 24 ++++++++++++++++++++++- cpp/purify/setup_utils.h | 2 ++ cpp/uncertainty_quantification/uq_main.cc | 4 +++- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/cpp/main.cc b/cpp/main.cc index 7a6314b27..52d4dc249 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -21,6 +21,11 @@ #include #include #include "purify/setup_utils.h" +#include + +#ifdef PURIFY_ONNXRT +#include +#endif using namespace purify; @@ -118,6 +123,22 @@ int main(int argc, const char **argv) { params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, params.epsilonConvergenceScaling(), operator_norm); if (params.algorithm() == "fb") + { + std::shared_ptr> f; + if(params.diffFuncType() == diff_func_type::L2Norm_with_CRR) + { + #ifdef PURIFY_ONNXRT + f = std::make_shared>(params.CRR_function_model_path(), + params.CRR_gradient_model_path(), + sigma, + params.CRR_mu(), + params.CRR_lambda(), + *measurements_transform); + #else + throw std::runtime_error("CRR approach cannot be used with ONNXRT off"); + #endif + } + fb = factory::fb_factory>( params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, sigma * params.epsilonScaling() / flux_scale, @@ -127,7 +148,8 @@ int main(int argc, const char **argv) { (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and (not params.positiveValueConstraint()), params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, operator_norm, - params.model_path(), params.nondiffFuncType()); + params.model_path(), params.nondiffFuncType(), f); + } if (params.algorithm() == "primaldual") primaldual = factory::primaldual_factory>( params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, diff --git a/cpp/purify/setup_utils.h b/cpp/purify/setup_utils.h index d5ef9c0b3..fdce82db2 100644 --- a/cpp/purify/setup_utils.h +++ b/cpp/purify/setup_utils.h @@ -8,6 +8,8 @@ #include "purify/read_measurements.h" #include "purify/yaml-parser.h" #include "purify/logging.h" +#include +#include using namespace purify; diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index d912c2b13..3a638d5b0 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -17,8 +17,10 @@ #include #include #include -#include +#ifdef PURIFY_ONNXRT +#include +#endif using VectorC = sopt::Vector>;