Skip to content

Commit

Permalink
Add local to csr_eq_segment
Browse files Browse the repository at this point in the history
  • Loading branch information
Xewar313 committed Nov 20, 2024
1 parent 55185dc commit b7704ea
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
3 changes: 2 additions & 1 deletion include/dr/mp/algorithms/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ inline auto dpl_reduce(rng::forward_range auto &&r, auto &&binary_op) {
sycl::known_identity_v<Fn, T>, binary_op);
} else {
dr::drlog.debug(" peel 1st value\n");
auto base = *rng::begin(r);
return std::reduce(dpl_policy(),
dr::__detail::direct_iterator(rng::begin(r) + 1),
dr::__detail::direct_iterator(rng::end(r)),
sycl_get(*rng::begin(r)), binary_op);
sycl_get(base), binary_op);
}
}
#else
Expand Down
30 changes: 30 additions & 0 deletions include/dr/mp/containers/matrix_formats/csr_eq_segment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,36 @@ template <typename DSM> class csr_eq_segment_iterator {
return dr::__detail::drop_segments(dsm_->segments(), segment_index_,
index_);
}

auto local() const {
const auto my_process_segment_index = dsm_->rows_backend_.getrank();

assert(my_process_segment_index == segment_index_);
// auto offset = dsm_->row_offsets_[segment_index_];
// auto row_size = dsm_->row_size_;
auto segment_size = dsm_->vals_data_->segment_size();
auto local_vals = dsm_->vals_data_->segments()[segment_index_].begin().local();
auto local_vals_range = rng::subrange(local_vals, local_vals + segment_size);
auto local_cols = dsm_->cols_data_->segments()[segment_index_].begin().local();
auto local_cols_range = rng::subrange(local_cols, local_cols + segment_size);
// auto local_rows = dsm_->rows_data_;
auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
auto enumerated_zipped = rng::views::enumerate(zipped_results);
auto transformer = [&](auto entry) {
auto [index, pair] = entry;
auto [val, column] = pair;
auto row = 0; //TODO fix calculating row - it results in segfault
// auto row = rng::distance(
// local_rows,
// std::upper_bound(local_rows, local_rows + row_size, offset + index) -
// 1);
dr::index<index_type> index_obj(row, column);
value_type entry_obj(index_obj, val);
return entry_obj;
};
auto transformed_res = rng::transform_view(enumerated_zipped, transformer);
return transformed_res.begin();
}

private:
// all fields need to be initialized by default ctor so every default
Expand Down

0 comments on commit b7704ea

Please sign in to comment.