diff --git a/include/dr/mp/algorithms/reduce.hpp b/include/dr/mp/algorithms/reduce.hpp index 839dc83a4a..166abfd44a 100644 --- a/include/dr/mp/algorithms/reduce.hpp +++ b/include/dr/mp/algorithms/reduce.hpp @@ -35,10 +35,11 @@ inline auto dpl_reduce(rng::forward_range auto &&r, auto &&binary_op) { sycl::known_identity_v, binary_op); } else { dr::drlog.debug(" peel 1st value\n"); + auto base = *rng::begin(r); return std::reduce(dpl_policy(), dr::__detail::direct_iterator(rng::begin(r) + 1), dr::__detail::direct_iterator(rng::end(r)), - sycl_get(*rng::begin(r)), binary_op); + sycl_get(base), binary_op); } } #else diff --git a/include/dr/mp/containers/matrix_formats/csr_eq_segment.hpp b/include/dr/mp/containers/matrix_formats/csr_eq_segment.hpp index f08c1dc4ef..527dc82ffb 100644 --- a/include/dr/mp/containers/matrix_formats/csr_eq_segment.hpp +++ b/include/dr/mp/containers/matrix_formats/csr_eq_segment.hpp @@ -235,6 +235,36 @@ template class csr_eq_segment_iterator { return dr::__detail::drop_segments(dsm_->segments(), segment_index_, index_); } + + auto local() const { + const auto my_process_segment_index = dsm_->rows_backend_.getrank(); + + assert(my_process_segment_index == segment_index_); + // auto offset = dsm_->row_offsets_[segment_index_]; + // auto row_size = dsm_->row_size_; + auto segment_size = dsm_->vals_data_->segment_size(); + auto local_vals = dsm_->vals_data_->segments()[segment_index_].begin().local(); + auto local_vals_range = rng::subrange(local_vals, local_vals + segment_size); + auto local_cols = dsm_->cols_data_->segments()[segment_index_].begin().local(); + auto local_cols_range = rng::subrange(local_cols, local_cols + segment_size); + // auto local_rows = dsm_->rows_data_; + auto zipped_results = rng::views::zip(local_vals_range, local_cols_range); + auto enumerated_zipped = rng::views::enumerate(zipped_results); + auto transformer = [&](auto entry) { + auto [index, pair] = entry; + auto [val, column] = pair; + auto row = 0; //TODO fix calculating row - it results in segfault + // auto row = rng::distance( + // local_rows, + // std::upper_bound(local_rows, local_rows + row_size, offset + index) - + // 1); + dr::index index_obj(row, column); + value_type entry_obj(index_obj, val); + return entry_obj; + }; + auto transformed_res = rng::transform_view(enumerated_zipped, transformer); + return transformed_res.begin(); + } private: // all fields need to be initialized by default ctor so every default