diff --git a/src/stdlib_sparse_spmv.fypp b/src/stdlib_sparse_spmv.fypp index c9cdbcac9..2f2e4bb45 100644 --- a/src/stdlib_sparse_spmv.fypp +++ b/src/stdlib_sparse_spmv.fypp @@ -418,7 +418,7 @@ contains character(1), intent(in), optional :: op ${t1}$ :: alpha_ character(1) :: op_ - integer(ilp) :: i, j, k, nz, rowidx, num_chunks, rm + integer(ilp) :: i, nz, rowidx, num_chunks, rm op_ = sparse_op_none; if(present(op)) op_ = op alpha_ = one_${s1}$ @@ -447,12 +447,7 @@ contains do i = 1, num_chunks nz = ia(i+1) - ia(i) rowidx = (i - 1)*${chunk}$ + 1 - associate(col => ja(1:${chunk}$,ia(i):ia(i)+nz-1), mat => data(1:${chunk}$,ia(i):ia(i)+nz-1), & - & x => vec_x, y => vec_y(rowidx:rowidx+${chunk}$-1) ) - do j = 1, nz - where(col(:,j) > 0) y = y + alpha_ * mat(:,j) * x(col(:,j)) - end do - end associate + call chunk_kernel_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x,vec_y(rowidx:)) end do #:endfor end select @@ -462,12 +457,7 @@ contains i = num_chunks + 1 nz = ia(i+1) - ia(i) rowidx = (i - 1)*cs + 1 - associate(col => ja(1:cs,ia(i):ia(i)+nz-1), mat => data(1:cs,ia(i):ia(i)+nz-1), & - & x => vec_x, y => vec_y(rowidx:rowidx+rm-1) ) - do j = 1, nz - where(col(1:rm,j) > 0) y = y + alpha_ * mat(1:rm,j) * x(col(1:rm,j)) - end do - end associate + call chunk_kernel_rm(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x,vec_y(rowidx:)) end if else if( storage == sparse_full .and. op_==sparse_op_transpose ) then @@ -478,14 +468,7 @@ contains do i = 1, num_chunks nz = ia(i+1) - ia(i) rowidx = (i - 1)*${chunk}$ + 1 - associate(col => ja(1:${chunk}$,ia(i):ia(i)+nz-1), mat => data(1:${chunk}$,ia(i):ia(i)+nz-1), & - & x => vec_x(rowidx:rowidx+${chunk}$-1), y => vec_y ) - do j = 1, nz - do k = 1, ${chunk}$ - if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * mat(k,j) * x(k) - end do - end do - end associate + call chunk_kernel_trans_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y) end do #:endfor end select @@ -495,14 +478,7 @@ contains i = num_chunks + 1 nz = ia(i+1) - ia(i) rowidx = (i - 1)*cs + 1 - associate(col => ja(1:cs,ia(i):ia(i)+nz-1), mat => data(1:cs,ia(i):ia(i)+nz-1), & - & x => vec_x(rowidx:rowidx+rm-1), y => vec_y ) - do j = 1, nz - do k = 1, rm - if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * mat(k,j) * x(k) - end do - end do - end associate + call chunk_kernel_rm_trans(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y) end if #:if t1.startswith('complex') @@ -514,14 +490,7 @@ contains do i = 1, num_chunks nz = ia(i+1) - ia(i) rowidx = (i - 1)*${chunk}$ + 1 - associate(col => ja(1:${chunk}$,ia(i):ia(i)+nz-1), mat => data(1:${chunk}$,ia(i):ia(i)+nz-1), & - & x => vec_x(rowidx:rowidx+${chunk}$-1), y => vec_y ) - do j = 1, nz - do k = 1, ${chunk}$ - if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * conjg(mat(k,j)) * x(k) - end do - end do - end associate + call chunk_kernel_herm_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y) end do #:endfor end select @@ -531,14 +500,7 @@ contains i = num_chunks + 1 nz = ia(i+1) - ia(i) rowidx = (i - 1)*cs + 1 - associate(col => ja(1:cs,ia(i):ia(i)+nz-1), mat => data(1:cs,ia(i):ia(i)+nz-1), & - & x => vec_x(rowidx:rowidx+rm-1), y => vec_y ) - do j = 1, nz - do k = 1, rm - if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * conjg(mat(k,j)) * x(k) - end do - end do - end associate + call chunk_kernel_rm_herm(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y) end if #:endif else @@ -547,6 +509,83 @@ contains end if end associate + contains + #:for chunk in CHUNKS + pure subroutine chunk_kernel_${chunk}$(n,a,col,x,y) + integer, value :: n + ${t1}$, intent(in) :: a(${chunk}$,n), x(*) + integer(ilp), intent(in) :: col(${chunk}$,n) + ${t1}$, intent(inout) :: y(${chunk}$) + integer :: j + do j = 1, n + where(col(:,j) > 0) y = y + alpha_ * a(:,j) * x(col(:,j)) + end do + end subroutine + pure subroutine chunk_kernel_trans_${chunk}$(n,a,col,x,y) + integer, value :: n + ${t1}$, intent(in) :: a(${chunk}$,n), x(${chunk}$) + integer(ilp), intent(in) :: col(${chunk}$,n) + ${t1}$, intent(inout) :: y(*) + integer :: j, k + do j = 1, n + do k = 1, ${chunk}$ + if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * a(k,j) * x(k) + end do + end do + end subroutine + #:if t1.startswith('complex') + pure subroutine chunk_kernel_herm_${chunk}$(n,a,col,x,y) + integer, value :: n + ${t1}$, intent(in) :: a(${chunk}$,n), x(${chunk}$) + integer(ilp), intent(in) :: col(${chunk}$,n) + ${t1}$, intent(inout) :: y(*) + integer :: j, k + do j = 1, n + do k = 1, ${chunk}$ + if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * conjg(a(k,j)) * x(k) + end do + end do + end subroutine + #:endif + #:endfor + + pure subroutine chunk_kernel_rm(n,cs,r,a,col,x,y) + integer, value :: n, cs, r + ${t1}$, intent(in) :: a(cs,n), x(*) + integer(ilp), intent(in) :: col(cs,n) + ${t1}$, intent(inout) :: y(r) + integer :: j + do j = 1, n + where(col(1:r,j) > 0) y = y + alpha_ * a(1:r,j) * x(col(1:r,j)) + end do + end subroutine + pure subroutine chunk_kernel_rm_trans(n,cs,r,a,col,x,y) + integer, value :: n, cs, r + ${t1}$, intent(in) :: a(cs,n), x(r) + integer(ilp), intent(in) :: col(cs,n) + ${t1}$, intent(inout) :: y(*) + integer :: j, k + do j = 1, n + do k = 1, r + if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * a(k,j) * x(k) + end do + end do + end subroutine + #:if t1.startswith('complex') + pure subroutine chunk_kernel_rm_herm(n,cs,r,a,col,x,y) + integer, value :: n, cs, r + ${t1}$, intent(in) :: a(cs,n), x(r) + integer(ilp), intent(in) :: col(cs,n) + ${t1}$, intent(inout) :: y(*) + integer :: j, k + do j = 1, n + do k = 1, r + if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * conjg(a(k,j)) * x(k) + end do + end do + end subroutine + #:endif + end subroutine #:endfor