From 61a40dd845c9e47d55bae3400ab8200f5166a756 Mon Sep 17 00:00:00 2001 From: Tommy Hofmann Date: Thu, 31 Oct 2024 23:11:31 +0100 Subject: [PATCH] feat: make Vector{fmpz} * fmpz_mat faster --- src/flint/fmpz_mat.jl | 100 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 90 insertions(+), 10 deletions(-) diff --git a/src/flint/fmpz_mat.jl b/src/flint/fmpz_mat.jl index 86bf7744b..dda2c543e 100644 --- a/src/flint/fmpz_mat.jl +++ b/src/flint/fmpz_mat.jl @@ -1784,16 +1784,6 @@ end addmul!(z::ZZMatrixOrPtr, a::ZZMatrixOrPtr, b::Integer) = addmul!(z, a, flintify(b)) addmul!(z::ZZMatrixOrPtr, a::IntegerUnionOrPtr, b::ZZMatrixOrPtr) = addmul!(z, b, a) -function mul!(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem}) - @ccall libflint.fmpz_mat_mul_fmpz_vec_ptr(z::Ptr{Ref{ZZRingElem}}, a::Ref{ZZMatrix}, b::Ptr{Ref{ZZRingElem}}, length(b)::Int)::Nothing - return z -end - -function mul!(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr) - @ccall libflint.fmpz_mat_fmpz_vec_mul_ptr(z::Ptr{Ref{ZZRingElem}}, a::Ptr{Ref{ZZRingElem}}, length(a)::Int, b::Ref{ZZMatrix})::Nothing - return z -end - function Generic.add_one!(a::ZZMatrix, i::Int, j::Int) @boundscheck _checkbounds(a, i, j) GC.@preserve a begin @@ -1819,6 +1809,96 @@ function shift!(g::ZZMatrix, l::Int) return g end +################################################################################ +# +# Vector * Matrix and Matrix * Vector +# +################################################################################ + +# Vector{fmpz} * fmpz_mat can be performed using +# - fmpz_mat_fmpz_vec_mul_ptr +# - or conversion + fmpz_mat_mul +# +# The fmpz_mat_fmpz_vec_mul_ptr variants are not optimized. +# Thus, if the conversion is negliable, we convert and call fmpz_mat. +# The conversion is done on the julia side, trying to reduce the number of +# allocations and objects tracked by the GC. + +function _very_unsafe_convert(::Type{ZZMatrix}, a::Vector{ZZRingElem}, row = true) + # a must be GC.@preserved + # row = true -> make it a row + # row = false -> make it a column + M = Nemo.@new_struct(ZZMatrix) + Me = zeros(Int, length(a)) + M.entries = reinterpret(Ptr{ZZRingElem}, pointer(Me)) + if row + Mep = [pointer(Me)] + M.rows = reinterpret(Ptr{Ptr{ZZRingElem}}, pointer(Mep)) + M.r = 1 + M.c = length(a) + else + M.r = length(a) + M.c = 1 + Mep = [pointer(Me) + 8*(i - 1) for i in 1:length(a)] + M.rows = reinterpret(Ptr{Ptr{ZZRingElem}}, pointer(Mep)) + end + for i in 1:length(a) + Me[i] = a[i].d + end + return M, Me, Mep +end + +function mul!_flint(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem}) + ccall((:fmpz_mat_mul_fmpz_vec_ptr, libflint), Nothing, + (Ptr{Ref{ZZRingElem}}, Ref{ZZMatrix}, Ptr{Ref{ZZRingElem}}, Int), + z, a, b, length(b)) + return z +end + +function mul!_flint(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr) + ccall((:fmpz_mat_fmpz_vec_mul_ptr, libflint), Nothing, + (Ptr{Ref{ZZRingElem}}, Ptr{Ref{ZZRingElem}}, Int, Ref{ZZMatrix}), + z, a, length(a), b) + return z +end + +function mul!(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem}) + # cutoff for the flint method + if nrows(a) < 50 && maximum(nbits, a) < 10 + return mul!_flint(z, a, b) + end + + GC.@preserve z b begin + bb, dk1, dk2 = _very_unsafe_convert(ZZMatrix, b, false) + zz, dk3, dk4 = _very_unsafe_convert(ZZMatrix, z, false) + GC.@preserve dk1 dk2 dk3 dk4 begin + mul!(zz, a, bb) + for i in 1:length(z) + z[i].d = unsafe_load(zz.entries, i).d + end + end + end + return z +end + +function mul!(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr) + # cutoff for the flint method + if nrows(b) < 50 && maximum(nbits, b) < 10 + return mul!_flint(z, a, b) + end + GC.@preserve z a begin + aa, dk1, dk2 = _very_unsafe_convert(ZZMatrix, a) + zz, dk3, dk4 = _very_unsafe_convert(ZZMatrix, z) + GC.@preserve dk1 dk2 dk3 dk4 begin + mul!(zz, aa, b) + for i in 1:length(z) + z[i].d = unsafe_load(zz.entries, i).d + end + end + end + return z +end + ############################################################################### # # Parent object call overloads