Skip to content

Commit

Permalink
feat: make Vector{fmpz} * fmpz_mat faster
Browse files Browse the repository at this point in the history
  • Loading branch information
thofma committed Nov 19, 2024
1 parent 80d1172 commit f81db55
Showing 1 changed file with 82 additions and 2 deletions.
84 changes: 82 additions & 2 deletions src/flint/fmpz_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1784,13 +1784,93 @@ end
# Julia `Integer` scalar: convert it with `flintify` and re-dispatch.
function addmul!(z::ZZMatrixOrPtr, a::ZZMatrixOrPtr, b::Integer)
  return addmul!(z, a, flintify(b))
end

# Scalar given before the matrix: swap the two factors and re-dispatch to the
# (matrix, scalar) method.
function addmul!(z::ZZMatrixOrPtr, a::IntegerUnionOrPtr, b::ZZMatrixOrPtr)
  return addmul!(z, b, a)
end

################################################################################
#
# Vector * Matrix and Matrix * Vector
#
################################################################################

# Vector{fmpz} * fmpz_mat can be performed using
# - fmpz_mat_fmpz_vec_mul_ptr
# - or conversion + fmpz_mat_mul
#
# The fmpz_mat_fmpz_vec_mul_ptr variants are not optimized.
# Thus, if the cost of the conversion is negligible, we convert and call fmpz_mat_mul.
# The conversion is done on the julia side, trying to reduce the number of
# allocations and objects tracked by the GC.

# Build a shallow ZZMatrix over the entries of the vector `a`, with the matrix
# data living in Julia-owned arrays instead of flint-allocated memory.
#
# Arguments:
# - `a`:   the vector whose entries are mirrored; it must be GC.@preserved by
#          the caller for as long as the returned matrix is in use.
# - `row`: `true`  -> the result is a 1 x length(a) matrix (a row),
#          `false` -> the result is a length(a) x 1 matrix (a column).
#
# Returns `(M, Me, Mep)`:
# - `M`:   the ZZMatrix struct whose `entries` and `rows` fields point into
#          `Me` and `Mep` respectively,
# - `Me`:  `Vector{Int}` holding a copy of the `d` field of every entry of `a`,
# - `Mep`: vector of row pointers into `Me`.
# `Me` and `Mep` must also be GC.@preserved while `M` is in use, since `M`
# only stores raw pointers to their data.
#
# NOTE(review): the entries of `M` are word-for-word copies of `a[i].d`, so `M`
# shares any underlying big-integer storage with `a`. Presumably `M` must
# therefore never be cleared/finalized as an ordinary ZZMatrix -- confirm.
# NOTE(review): Codecov reported this function as not covered by tests at the
# time of the commit.
function _very_unsafe_convert(::Type{ZZMatrix}, a::Vector{ZZRingElem}, row = true)
  n = length(a)
  M = Nemo.@new_struct(ZZMatrix)
  Me = zeros(Int, n)
  M.entries = reinterpret(Ptr{ZZRingElem}, pointer(Me))
  if row
    # a single row starting at the first entry
    Mep = [pointer(Me)]
    M.r = 1
    M.c = n
  else
    # one row per entry; pointer(Me, i) uses the element size of `Me`, so no
    # word size is hard-coded here
    Mep = [pointer(Me, i) for i in 1:n]
    M.r = n
    M.c = 1
  end
  M.rows = reinterpret(Ptr{Ptr{ZZRingElem}}, pointer(Mep))
  for i in 1:n
    Me[i] = a[i].d
  end
  return M, Me, Mep
end

# In-place matrix * vector product: store `a * b` in `z` and return `z`.
#
# Two strategies are used:
# - if `a` has fewer than 50 rows and all entries have fewer than 10 bits,
#   the flint routine wrapped by `mul!_flint` is called directly;
# - otherwise `b` and `z` are wrapped as column matrices via
#   `_very_unsafe_convert` and the matrix-matrix `mul!` is used.
#
# NOTE(review): the conversion path builds a length(z) x 1 result and a
# length(b) x 1 operand without checking them against the dimensions of `a`;
# presumably callers guarantee length(z) == nrows(a) and
# length(b) == ncols(a) -- confirm.
# NOTE(review): Codecov reported the conversion path as not covered by tests
# at the time of the commit.
function mul!(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem})
  # cutoff for the flint method
  if nrows(a) < 50 && maximum(nbits, a) < 10
    return mul!_flint(z, a, b)
  end

  GC.@preserve z b begin
    bb, dk1, dk2 = _very_unsafe_convert(ZZMatrix, b, false)
    zz, dk3, dk4 = _very_unsafe_convert(ZZMatrix, z, false)
    # dk1..dk4 are the Julia arrays backing the raw pointers stored in bb, zz
    GC.@preserve dk1 dk2 dk3 dk4 begin
      mul!(zz, a, bb)
      # copy the (possibly changed) raw entry words of the result back into z
      for i in eachindex(z)
        z[i].d = unsafe_load(zz.entries, i).d
      end
    end
  end
  return z
end

# Matrix * vector via flint's fmpz_mat_mul_fmpz_vec_ptr: the result is written
# into `z`, which is returned. The length of `b` is passed to flint explicitly.
function mul!_flint(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem})
  @ccall libflint.fmpz_mat_mul_fmpz_vec_ptr(
    z::Ptr{Ref{ZZRingElem}},
    a::Ref{ZZMatrix},
    b::Ptr{Ref{ZZRingElem}},
    length(b)::Int,
  )::Nothing
  return z
end

# Vector * matrix via flint's fmpz_mat_fmpz_vec_mul_ptr: the result is written
# into `z`, which is returned. The length of `a` is passed to flint explicitly.
function mul!_flint(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr)
  @ccall libflint.fmpz_mat_fmpz_vec_mul_ptr(
    z::Ptr{Ref{ZZRingElem}},
    a::Ptr{Ref{ZZRingElem}},
    length(a)::Int,
    b::Ref{ZZMatrix},
  )::Nothing
  return z
end

# In-place vector * matrix product: store `a * b` in `z` and return `z`.
#
# Two strategies are used:
# - if `b` has fewer than 50 rows and all entries have fewer than 10 bits,
#   the flint routine wrapped by `mul!_flint` is called directly;
# - otherwise `a` and `z` are wrapped as row matrices via
#   `_very_unsafe_convert` and the matrix-matrix `mul!` is used.
#
# NOTE(review): the conversion path builds a 1 x length(z) result and a
# 1 x length(a) operand without checking them against the dimensions of `b`;
# presumably callers guarantee length(a) == nrows(b) and
# length(z) == ncols(b) -- confirm.
# NOTE(review): Codecov reported the conversion path as not covered by tests
# at the time of the commit.
function mul!(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr)
  # cutoff for the flint method
  if nrows(b) < 50 && maximum(nbits, b) < 10
    return mul!_flint(z, a, b)
  end

  GC.@preserve z a begin
    aa, dk1, dk2 = _very_unsafe_convert(ZZMatrix, a)
    zz, dk3, dk4 = _very_unsafe_convert(ZZMatrix, z)
    # dk1..dk4 are the Julia arrays backing the raw pointers stored in aa, zz
    GC.@preserve dk1 dk2 dk3 dk4 begin
      mul!(zz, aa, b)
      # copy the (possibly changed) raw entry words of the result back into z
      for i in eachindex(z)
        z[i].d = unsafe_load(zz.entries, i).d
      end
    end
  end
  return z
end

Expand Down

0 comments on commit f81db55

Please sign in to comment.