Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make vector times matrix faster #1937

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions src/flint/fmpz_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1784,16 +1784,6 @@
addmul!(z::ZZMatrixOrPtr, a::ZZMatrixOrPtr, b::Integer) = addmul!(z, a, flintify(b))
addmul!(z::ZZMatrixOrPtr, a::IntegerUnionOrPtr, b::ZZMatrixOrPtr) = addmul!(z, b, a)

# Matrix * vector: hand `a` and the entries of `b` to FLINT's
# `fmpz_mat_mul_fmpz_vec_ptr`, which writes its result into `z`; `z` is returned.
function mul!(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem})
  ccall((:fmpz_mat_mul_fmpz_vec_ptr, libflint), Nothing,
        (Ptr{Ref{ZZRingElem}}, Ref{ZZMatrix}, Ptr{Ref{ZZRingElem}}, Int),
        z, a, b, length(b))
  return z
end

# Vector * matrix: hand the entries of `a` and the matrix `b` to FLINT's
# `fmpz_mat_fmpz_vec_mul_ptr`, which writes its result into `z`; `z` is returned.
function mul!(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr)
  ccall((:fmpz_mat_fmpz_vec_mul_ptr, libflint), Nothing,
        (Ptr{Ref{ZZRingElem}}, Ptr{Ref{ZZRingElem}}, Int, Ref{ZZMatrix}),
        z, a, length(a), b)
  return z
end

function Generic.add_one!(a::ZZMatrix, i::Int, j::Int)
@boundscheck _checkbounds(a, i, j)
GC.@preserve a begin
Expand All @@ -1819,6 +1809,96 @@
return g
end

################################################################################
#
# Vector * Matrix and Matrix * Vector
#
################################################################################

# Vector{fmpz} * fmpz_mat can be performed using
# - fmpz_mat_fmpz_vec_mul_ptr
# - or conversion + fmpz_mat_mul
#
# The fmpz_mat_fmpz_vec_mul_ptr variants are not optimized.
# Thus, if the conversion is negligible, we convert and call fmpz_mat_mul.
# The conversion is done on the julia side, trying to reduce the number of
# allocations and objects tracked by the GC.

# Wrap the entries of `a` in a ZZMatrix struct whose fields are filled in by hand.
#
# A fresh `Vector{Int}` (`Me`) receives the `d` field of every entry of `a` and
# is used as the entry storage of the returned matrix `M`; `Mep` holds the row
# pointers into `Me`. `row` selects a 1 x length(a) layout (`true`, the default)
# or a length(a) x 1 layout (`false`).
#
# Returns `(M, Me, Mep)`. `M` only stores raw pointers into `Me` and `Mep`; the
# callers visible in this file wrap both in GC.@preserve while `M` is in use.
function _very_unsafe_convert(::Type{ZZMatrix}, a::Vector{ZZRingElem}, row = true)

Check warning on line 1827 in src/flint/fmpz_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/flint/fmpz_mat.jl#L1827

Added line #L1827 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function _very_unsafe_convert(::Type{ZZMatrix}, a::Vector{ZZRingElem}, row = true)
function _very_unsafe_convert(::Type{ZZMatrix}, a::Vector{ZZRingElem}, Val{row} = Val(true)) where {row}

this should push performance ever so slightly further as it eliminates a jump in if row

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt this is measurable

# a must be GC.@preserved
# row = true -> make it a row
# row = false -> make it a column
# the fields of M (entries, rows, r, c) are all assigned manually below
M = Nemo.@new_struct(ZZMatrix)
# Me is the backing storage for the matrix entries, one Int per entry of a
Me = zeros(Int, length(a))
M.entries = reinterpret(Ptr{ZZRingElem}, pointer(Me))
if row
# single row: one row pointer, pointing at the start of Me
Mep = [pointer(Me)]
M.rows = reinterpret(Ptr{Ptr{ZZRingElem}}, pointer(Mep))
M.r = 1
M.c = length(a)

Check warning on line 1838 in src/flint/fmpz_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/flint/fmpz_mat.jl#L1831-L1838

Added lines #L1831 - L1838 were not covered by tests
else
M.r = length(a)
M.c = 1
# single column: one row pointer per entry, spaced 8 bytes apart
# NOTE(review): the stride 8 presumably stands for sizeof(Int) on 64-bit
# platforms - confirm, or use sizeof(Int)
Mep = [pointer(Me) + 8*(i - 1) for i in 1:length(a)]
M.rows = reinterpret(Ptr{Ptr{ZZRingElem}}, pointer(Mep))

Check warning on line 1843 in src/flint/fmpz_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/flint/fmpz_mat.jl#L1840-L1843

Added lines #L1840 - L1843 were not covered by tests
end
# copy the `d` field of every entry of a into the backing storage
for i in 1:length(a)
Me[i] = a[i].d
end
return M, Me, Mep

Check warning on line 1848 in src/flint/fmpz_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/flint/fmpz_mat.jl#L1845-L1848

Added lines #L1845 - L1848 were not covered by tests
end

# Matrix * vector through FLINT's `fmpz_mat_mul_fmpz_vec_ptr`: the result is
# written into `z`, which is returned.
function mul!_flint(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem})
  @ccall libflint.fmpz_mat_mul_fmpz_vec_ptr(z::Ptr{Ref{ZZRingElem}}, a::Ref{ZZMatrix}, b::Ptr{Ref{ZZRingElem}}, length(b)::Int)::Nothing
  return z
end

# Vector * matrix through FLINT's `fmpz_mat_fmpz_vec_mul_ptr`: the result is
# written into `z`, which is returned.
function mul!_flint(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr)
  @ccall libflint.fmpz_mat_fmpz_vec_mul_ptr(z::Ptr{Ref{ZZRingElem}}, a::Ptr{Ref{ZZRingElem}}, length(a)::Int, b::Ref{ZZMatrix})::Nothing
  return z
end

# Matrix * vector product written into `z`; returns `z`.
#
# If `a` has fewer than 50 rows and its maximal `nbits` is below 10, the work
# is delegated to `mul!_flint`. Otherwise `b` and `z` are wrapped as column
# matrices by `_very_unsafe_convert`, multiplied with the matrix `mul!`, and
# the resulting `d` fields are copied back into the entries of `z`.
function mul!(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem})
# cutoff for the flint method
if nrows(a) < 50 && maximum(nbits, a) < 10
return mul!_flint(z, a, b)
end

GC.@preserve z b begin
# row = false: wrap both vectors as column matrices
bb, dk1, dk2 = _very_unsafe_convert(ZZMatrix, b, false)
zz, dk3, dk4 = _very_unsafe_convert(ZZMatrix, z, false)
# dk1..dk4 back the raw pointers stored in bb and zz, so keep them alive
GC.@preserve dk1 dk2 dk3 dk4 begin
mul!(zz, a, bb)
# write the values held in zz's entry storage back into the elements of z
for i in 1:length(z)
z[i].d = unsafe_load(zz.entries, i).d
end

Check warning on line 1878 in src/flint/fmpz_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/flint/fmpz_mat.jl#L1871-L1878

Added lines #L1871 - L1878 were not covered by tests
end
end
return z

Check warning on line 1881 in src/flint/fmpz_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/flint/fmpz_mat.jl#L1881

Added line #L1881 was not covered by tests
end

# Vector * matrix product written into `z`; returns `z`.
#
# If `b` has fewer than 50 rows and its maximal `nbits` is below 10, the work
# is delegated to `mul!_flint`. Otherwise `a` and `z` are wrapped as row
# matrices by `_very_unsafe_convert`, multiplied with the matrix `mul!`, and
# the resulting `d` fields are copied back into the entries of `z`.
function mul!(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr)
# cutoff for the flint method
if nrows(b) < 50 && maximum(nbits, b) < 10
return mul!_flint(z, a, b)
end
GC.@preserve z a begin
# default row = true: wrap both vectors as row matrices
aa, dk1, dk2 = _very_unsafe_convert(ZZMatrix, a)
zz, dk3, dk4 = _very_unsafe_convert(ZZMatrix, z)
# dk1..dk4 back the raw pointers stored in aa and zz, so keep them alive
GC.@preserve dk1 dk2 dk3 dk4 begin
mul!(zz, aa, b)
# write the values held in zz's entry storage back into the elements of z
for i in 1:length(z)
z[i].d = unsafe_load(zz.entries, i).d
end

Check warning on line 1896 in src/flint/fmpz_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/flint/fmpz_mat.jl#L1889-L1896

Added lines #L1889 - L1896 were not covered by tests
end
end
return z

Check warning on line 1899 in src/flint/fmpz_mat.jl

View check run for this annotation

Codecov / codecov/patch

src/flint/fmpz_mat.jl#L1899

Added line #L1899 was not covered by tests
end

###############################################################################
#
# Parent object call overloads
Expand Down
Loading