
Commit 8a08826
Merge branch 'develop'
acdemiralp committed Nov 21, 2021
2 parents bf6a9b5 + c3d4454 commit 8a08826
Showing 5 changed files with 171 additions and 15 deletions.
1 change: 1 addition & 0 deletions include/mpi/all.hpp
@@ -71,6 +71,7 @@

#include <mpi/extensions/detach.hpp>
#include <mpi/extensions/future.hpp>
+#include <mpi/extensions/shared_variable.hpp>

#include <mpi/io/enums/access_mode.hpp>
#include <mpi/io/enums/seek_mode.hpp>
26 changes: 13 additions & 13 deletions include/mpi/core/window.hpp
@@ -28,7 +28,7 @@ class window
explicit window (const communicator& communicator, const std::int64_t size, const std::int32_t displacement_unit = 1, const bool shared = false, const information& information = mpi::information())
: managed_(true)
{
-    void* base_pointer; // Unused. Call base_pointer() explicitly.
+    void* base_pointer; // Unused. Call base_pointer() explicitly.
if (shared)
MPI_CHECK_ERROR_CODE(MPI_Win_allocate_shared, (size, displacement_unit, information.native(), communicator.native(), &base_pointer, &native_))
else
@@ -89,7 +89,7 @@ class window

// A static member function for construction is bad practice, but constructors cannot be given explicit template arguments when the type does not appear in the function arguments.
template <typename type, typename = std::enable_if_t<!std::is_same_v<type, void>>>
-  static window allocate (const communicator& communicator, const std::int64_t size, const bool shared = false, const information& information = mpi::information())
+  static window allocate (const communicator& communicator, const std::int64_t size = 1, const bool shared = false, const information& information = mpi::information())
{
window result;
void* base_pointer; // Unused. Call base_pointer() explicitly.
@@ -157,7 +157,7 @@ class window
MPI_CHECK_ERROR_CODE(MPI_Win_call_errhandler, (native_, value.native()))
}

-  template <typename type>
+  template <typename type> [[nodiscard]]
std::optional<type> attribute (const window_key_value& key) const
{
type result;
@@ -285,13 +285,13 @@ class window
}

// Remote memory access operations.
-  void get (void* source , const std::int32_t source_size, const data_type& source_data_type,
+  void get ( void* source , const std::int32_t source_size, const data_type& source_data_type,
const std::int32_t target_rank, const std::int64_t target_displacement, const std::int32_t target_size, const data_type& target_data_type) const
{
MPI_CHECK_ERROR_CODE(MPI_Get, (source, source_size, source_data_type.native(), target_rank, target_displacement, target_size, target_data_type.native(), native_))
}
template <typename type>
-  void get (const type& source ,
+  void get ( type& source ,
const std::int32_t target_rank, const std::int64_t target_displacement, const std::int32_t target_size, const data_type& target_data_type) const
{
using adapter = container_adapter<type>;
@@ -332,7 +332,7 @@ class window
}
template <typename source_type, typename result_type>
void get_accumulate (const source_type& source ,
-                       const result_type& result ,
+                             result_type& result ,
const std::int32_t target_rank, const std::int64_t target_displacement, const std::int32_t target_size, const data_type& target_data_type, const op& op = ops::sum) const
{
using source_adapter = container_adapter<source_type>;
@@ -343,38 +343,38 @@ class window
target_rank, target_displacement, target_size, target_data_type, op);
}

-  void fetch_and_op (const void* source, void* result, const data_type& data_type, const std::int32_t target_rank, const std::int64_t target_displacement, const op& op = ops::sum) const
+  void fetch_and_op     (const void* source, void* result, const data_type& data_type, const std::int32_t target_rank, const std::int64_t target_displacement, const op& op = ops::sum) const
{
MPI_CHECK_ERROR_CODE(MPI_Fetch_and_op, (source, result, data_type.native(), target_rank, target_displacement, op.native(), native_))
}
template <typename type>
-  void fetch_and_op (const type& source, const type& result, const std::int32_t target_rank, const std::int64_t target_displacement, const op& op = ops::sum) const
+  void fetch_and_op (const type& source,       type& result, const std::int32_t target_rank, const std::int64_t target_displacement, const op& op = ops::sum) const
{
using adapter = container_adapter<type>;
fetch_and_op(static_cast<const void*>(adapter::data(source)), static_cast<void*>(adapter::data(result)), adapter::data_type(), target_rank, target_displacement, op);
}

-  void compare_and_swap (const void* source, const void* compare, void* result, const data_type& data_type, const std::int32_t target_rank, const std::int64_t target_displacement) const
+  void compare_and_swap (const void* source, const void* compare, void* result, const data_type& data_type, const std::int32_t target_rank, const std::int64_t target_displacement) const
{
MPI_CHECK_ERROR_CODE(MPI_Compare_and_swap, (source, compare, result, data_type.native(), target_rank, target_displacement, native_))
}
template <typename type>
-  void compare_and_swap (const type& source, const type& compare, const type& result, const std::int32_t target_rank, const std::int64_t target_displacement) const
+  void compare_and_swap (const type& source, const type& compare,       type& result, const std::int32_t target_rank, const std::int64_t target_displacement) const
{
using adapter = container_adapter<type>;
compare_and_swap(static_cast<const void*>(adapter::data(source)), static_cast<const void*>(adapter::data(compare)), static_cast<void*>(adapter::data(result)), adapter::data_type(), target_rank, target_displacement);
}

// Request remote memory access operations.
-  request request_get (void* source , const std::int32_t source_size, const data_type& source_data_type,
+  request request_get ( void* source , const std::int32_t source_size, const data_type& source_data_type,
const std::int32_t target_rank, const std::int64_t target_displacement, const std::int32_t target_size, const data_type& target_data_type) const
{
request result(MPI_REQUEST_NULL, true);
MPI_CHECK_ERROR_CODE(MPI_Rget, (source, source_size, source_data_type.native(), target_rank, target_displacement, target_size, target_data_type.native(), native_, &result.native_))
return result;
}
template <typename type>
-  request request_get (const type& source ,
+  request request_get ( type& source ,
const std::int32_t target_rank, const std::int64_t target_displacement, const std::int32_t target_size, const data_type& target_data_type) const
{
using adapter = container_adapter<type>;
@@ -421,7 +421,7 @@ class window
}
template <typename source_type, typename result_type>
request request_get_accumulate(const source_type& source ,
-                                 const result_type& result ,
+                                       result_type& result ,
const std::int32_t target_rank, const std::int64_t target_displacement, const std::int32_t target_size, const data_type& target_data_type, const op& op = ops::sum) const
{
using source_adapter = container_adapter<source_type>;
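Note: the recurring change in this file turns the output parameters of the typed get, get_accumulate, fetch_and_op, and compare_and_swap overloads from const references into mutable references, so results can actually be written through them. A minimal usage sketch against the headers shown here (not part of the commit, not compiled; assumes an active mpi::environment as in the tests below):

#include <cstdint>

#include <mpi/all.hpp>

int main()
{
  mpi::environment environment;
  const auto&      communicator = mpi::world_communicator;

  // allocate's size parameter now defaults to 1 element of the given type.
  auto window = mpi::window::allocate<std::int32_t>(communicator);

  std::int32_t value(42), result(0);

  window.fence();
  if (communicator.rank() == 0)
    window.put(value , 0, 0, 1, mpi::type_traits<std::int32_t>::get_data_type());
  window.fence();

  // result now binds as a mutable reference (previously const type&).
  window.get  (result, 0, 0, 1, mpi::type_traits<std::int32_t>::get_data_type());
  window.fence();
}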
99 changes: 99 additions & 0 deletions include/mpi/extensions/shared_variable.hpp
@@ -0,0 +1,99 @@
#pragma once

#include <cstdint>

#include <mpi/core/communicators/communicator.hpp>
#include <mpi/core/type/type_traits.hpp>
#include <mpi/core/window.hpp>

namespace mpi
{
// The type must be MPI-compliant, i.e. have an associated MPI data type (see mpi::type_traits).
// Requires .synchronize() to be called explicitly after a .set(...) for the value to become visible on other ranks.
template <typename type>
class manual_shared_variable
{
public:
explicit manual_shared_variable (const communicator& communicator, const std::int32_t root = 0)
: communicator_(communicator), root_(root), window_(communicator_, communicator_.rank() == root_ ? type_traits<type>::get_data_type().size() : 0)
{

}
manual_shared_variable (const manual_shared_variable& that) = delete ;
manual_shared_variable ( manual_shared_variable&& temp) = default;
virtual ~manual_shared_variable () = default;
manual_shared_variable& operator=(const manual_shared_variable& that) = delete ;
manual_shared_variable& operator=( manual_shared_variable&& temp) = default;
manual_shared_variable& operator=(const type& value)
{
set(value);
return *this;
}
operator type() const
{
return get();
}

void set (const type& value)
{
window_.lock (root_, false);
window_.put (value , root_, 0, 1, type_traits<type>::get_data_type());
window_.unlock(root_);
}
[[nodiscard]]
type get () const
{
type result {};
window_.lock (root_, true);
window_.get (result, root_, 0, 1, type_traits<type>::get_data_type());
window_.unlock(root_);
return result;
}

void synchronize() const
{
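    // MPI_Win_fence is collective over the communicator; every rank must call this.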
window_.fence();
}

protected:
const communicator& communicator_;
std::int32_t root_ ;
window window_ ;
};

// The type must be MPI-compliant, i.e. have an associated MPI data type (see mpi::type_traits).
// Calls .synchronize() globally after a .set_if_rank(...), ensuring synchronization.
// Note that .set_if_rank(...) must be called on all processes in the communicator, as if it were a collective.
template <typename type>
class automatic_shared_variable : public manual_shared_variable<type>
{
public:
using base = manual_shared_variable<type>;

explicit automatic_shared_variable (const communicator& communicator, const std::int32_t root = 0)
: base(communicator, root)
{

}
automatic_shared_variable (const automatic_shared_variable& that) = delete ;
automatic_shared_variable ( automatic_shared_variable&& temp) = default;
~automatic_shared_variable () override = default;
automatic_shared_variable& operator=(const automatic_shared_variable& that) = delete ;
automatic_shared_variable& operator=( automatic_shared_variable&& temp) = default;
automatic_shared_variable& operator=(const type& value) = delete ;
operator type() const
{
return base::get();
}

void set_if_rank(const type& value, const std::int32_t rank)
{
if (base::communicator_.rank() == rank)
base::set(value);
base::synchronize();
}
};

template <typename type>
using shared_variable = automatic_shared_variable<type>;
}
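For orientation, set and get above use passive-target one-sided communication. In raw MPI the pair corresponds roughly to the following (a sketch with error handling omitted, assuming the boolean argument of window::lock selects a shared rather than exclusive lock):

// Approximate raw-MPI equivalent of manual_shared_variable<std::int32_t>::set/get,
// where win is a window whose single std::int32_t lives on rank root.
std::int32_t value = 42, result = 0;

MPI_Win_lock  (MPI_LOCK_EXCLUSIVE, root, 0, win);                      // window_.lock  (root_, false)
MPI_Put       (&value , 1, MPI_INT32_T, root, 0, 1, MPI_INT32_T, win); // window_.put   (value , root_, 0, 1, ...)
MPI_Win_unlock(root, win);                                             // window_.unlock(root_)

MPI_Win_lock  (MPI_LOCK_SHARED   , root, 0, win);                      // window_.lock  (root_, true)
MPI_Get       (&result, 1, MPI_INT32_T, root, 0, 1, MPI_INT32_T, win); // window_.get   (result, root_, 0, 1, ...)
MPI_Win_unlock(root, win);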
2 changes: 0 additions & 2 deletions tests/future_test.cpp
@@ -1,5 +1,3 @@
-#include <iostream>
-
#include "internal/doctest.h"

#define MPI_USE_EXCEPTIONS
58 changes: 58 additions & 0 deletions tests/shared_variable_test.cpp
@@ -0,0 +1,58 @@
#include "internal/doctest.h"

#define MPI_USE_EXCEPTIONS

#include <mpi/all.hpp>

TEST_CASE("Shared Variable Test")
{
mpi::environment environment ;
const auto& communicator = mpi::world_communicator;

{
mpi::manual_shared_variable<std::int32_t> variable(communicator);

if (communicator.rank() == 0)
variable = 42;

variable.synchronize(); // :(

if (communicator.rank() != 0)
REQUIRE(variable == 42);
}

{
mpi::manual_shared_variable<std::array<std::int32_t, 3>> variable(communicator);
if (communicator.rank() == 0)
variable = {1, 2, 3};

variable.synchronize(); // :(

if (communicator.rank() != 0)
{
auto value = variable.get();
REQUIRE(value[0] == 1);
REQUIRE(value[1] == 2);
REQUIRE(value[2] == 3);
}
}

{
mpi::shared_variable<std::int32_t> variable(communicator);
variable.set_if_rank(42, 0);
if (communicator.rank() != 0)
REQUIRE(variable == 42);
}

{
mpi::shared_variable<std::array<std::int32_t, 3>> variable(communicator);
variable.set_if_rank({1, 2, 3}, 0);
if (communicator.rank() != 0)
{
auto value = variable.get();
REQUIRE(value[0] == 1);
REQUIRE(value[1] == 2);
REQUIRE(value[2] == 3);
}
}
}
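A caveat that follows from the comment in shared_variable.hpp: synchronize() fences, and MPI_Win_fence is collective, so set_if_rank(...) must be reached by every rank even though only one rank writes. Guarding it by rank, as in this hypothetical misuse (not from the tests), deadlocks:

// WRONG: set_if_rank fences internally, and a fence is collective.
// Only rank 0 reaches it here, so rank 0 blocks in the fence while
// the remaining ranks never enter it.
if (communicator.rank() == 0)
  variable.set_if_rank(42, 0);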
