diff --git a/vendor/sll/include/sll/matrix_batch_csr.hpp b/vendor/sll/include/sll/matrix_batch_csr.hpp index 806647a83..c736c7b97 100644 --- a/vendor/sll/include/sll/matrix_batch_csr.hpp +++ b/vendor/sll/include/sll/matrix_batch_csr.hpp @@ -189,7 +189,15 @@ class MatrixBatchCsr : public MatrixBatch void setup_solver() final { std::shared_ptr const gko_exec = m_batch_matrix_csr->get_executor(); - + //Check if the indices array is sorted, and sort it if necessary. + //The values array corresponding to the indices is also reordered. + for (size_t i = 0; i < batch_size(); i++) { + std::unique_ptr> tmp_matrix + = m_batch_matrix_csr->create_view_for_item(i); + if (!tmp_matrix->is_sorted_by_column_index()) { + tmp_matrix->sort_by_column_index(); + } + } if constexpr ( Solver == MatrixBatchCsrSolver::CG || Solver == MatrixBatchCsrSolver::BICGSTAB) { // Create the solver factory @@ -212,7 +220,7 @@ class MatrixBatchCsr : public MatrixBatch .on(gko_exec); // Create the solvers - for (int i = 0; i < batch_size(); i++) { + for (size_t i = 0; i < batch_size(); i++) { m_solver.emplace_back(solver_factory->generate( m_batch_matrix_csr->create_const_view_for_item(i))); } @@ -247,7 +255,7 @@ class MatrixBatchCsr : public MatrixBatch if constexpr ( Solver == MatrixBatchCsrSolver::CG || Solver == MatrixBatchCsrSolver::BICGSTAB) { - for (int i = 0; i < batch_size(); i++) { + for (size_t i = 0; i < batch_size(); i++) { std::shared_ptr const gko_exec = m_solver[i]->get_executor(); // Create a logger to obtain the iteration counts and "implicit" residual norms for every system after the solve.