From 78e327b6c2bf53312fdb7503f52795c6d1c92d20 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Tue, 14 Nov 2023 16:05:07 +0000 Subject: [PATCH 01/10] fast mapreduce for specific operators --- src/FillArrays.jl | 13 +------------ src/fillbroadcast.jl | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 7ca7b78e..8cb6ba6c 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,7 +7,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent, similar, issorted + parent, similar, issorted, add_sum, mul_prod import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, @@ -551,22 +551,11 @@ for SMT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal) end -######### -# maximum/minimum -######### - -for op in (:maximum, :minimum) - @eval $op(x::AbstractFill) = getindex_value(x) -end - - ######### # Cumsum ######### # These methods are necessary to deal with infinite arrays -sum(x::AbstractFill) = getindex_value(x)*length(x) -sum(f, x::AbstractFill) = length(x) * f(getindex_value(x)) sum(x::AbstractZeros) = getindex_value(x) # needed to support infinite case diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 6cf11284..89897376 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -51,6 +51,21 @@ function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) end +firstval(a, b) = a +for (op, iterop) in ((:+, :*), (:*, :^), (:add_sum, :mul_prod), (:max, :firstval), (:min, :firstval), (:|, :firstval), (:&, :firstval)) + @eval function Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, dims) + fval = f(getindex_value(A)) + red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) + out = ($iterop)(fval, red) + Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) + end + @eval function Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, ::Colon) + fval = f(getindex_value(A)) + ($iterop)(fval, length(A)) + end +end + + function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) val(_...) = f(getindex_value(A), getindex_value(B)) reduce(op, map(val, A, B); kw...) From 25ddeaf607e4929d5decec27394916285709aff3 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Tue, 14 Nov 2023 16:39:05 +0000 Subject: [PATCH 02/10] Update fillbroadcast.jl --- src/fillbroadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 89897376..f5673d63 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -52,7 +52,7 @@ function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) end firstval(a, b) = a -for (op, iterop) in ((:+, :*), (:*, :^), (:add_sum, :mul_prod), (:max, :firstval), (:min, :firstval), (:|, :firstval), (:&, :firstval)) +for (op, iterop) in ((:+, :*), (:*, :^), (:add_sum, :mul_prod), (:mul_prod, :^), (:max, :firstval), (:min, :firstval), (:|, :firstval), (:&, :firstval)) @eval function Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, dims) fval = f(getindex_value(A)) red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) From 4d74d75e5b99cd90cc0a9edce37e87619a495dff Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Tue, 14 Nov 2023 17:06:56 +0000 Subject: [PATCH 03/10] any/all not needed --- src/FillArrays.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 8cb6ba6c..33e62315 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -629,13 +629,6 @@ function all(f::Function, IM::Eye{T}) where T return all(f(one(T))) end -# In particular, these make iszero(Eye(n)) efficient. -# use any/all on scalar to get Boolean error message -any(f::Function, x::AbstractFill) = !isempty(x) && any(f(getindex_value(x))) -all(f::Function, x::AbstractFill) = isempty(x) || all(f(getindex_value(x))) -any(x::AbstractFill) = any(identity, x) -all(x::AbstractFill) = all(identity, x) - count(x::AbstractOnes{Bool}) = length(x) count(x::AbstractZeros{Bool}) = 0 count(f, x::AbstractFill) = f(getindex_value(x)) ? length(x) : 0 From 97257598de0a50e078951df0911e092eed14d880 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Wed, 15 Nov 2023 11:10:41 +0000 Subject: [PATCH 04/10] Revert "any/all not needed" This reverts commit 4d74d75e5b99cd90cc0a9edce37e87619a495dff. --- src/FillArrays.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 33e62315..8cb6ba6c 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -629,6 +629,13 @@ function all(f::Function, IM::Eye{T}) where T return all(f(one(T))) end +# In particular, these make iszero(Eye(n)) efficient. +# use any/all on scalar to get Boolean error message +any(f::Function, x::AbstractFill) = !isempty(x) && any(f(getindex_value(x))) +all(f::Function, x::AbstractFill) = isempty(x) || all(f(getindex_value(x))) +any(x::AbstractFill) = any(identity, x) +all(x::AbstractFill) = all(identity, x) + count(x::AbstractOnes{Bool}) = length(x) count(x::AbstractZeros{Bool}) = 0 count(f, x::AbstractFill) = f(getindex_value(x)) ? length(x) : 0 From 7bacf0de4e67d4cb8e38b5cfc28219764f7becca Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Wed, 15 Nov 2023 12:29:46 +0000 Subject: [PATCH 05/10] try again --- src/FillArrays.jl | 6 ++---- src/fillbroadcast.jl | 1 - 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 8cb6ba6c..62ee8b75 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -631,10 +631,8 @@ end # In particular, these make iszero(Eye(n)) efficient. # use any/all on scalar to get Boolean error message -any(f::Function, x::AbstractFill) = !isempty(x) && any(f(getindex_value(x))) -all(f::Function, x::AbstractFill) = isempty(x) || all(f(getindex_value(x))) -any(x::AbstractFill) = any(identity, x) -all(x::AbstractFill) = all(identity, x) +Base._any(f::Function, x::AbstractFill, ::Colon) = !isempty(x) && any(f(getindex_value(x))) +Base._all(f::Function, x::AbstractFill, ::Colon) = isempty(x) || all(f(getindex_value(x))) count(x::AbstractOnes{Bool}) = length(x) count(x::AbstractZeros{Bool}) = 0 diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index f5673d63..86975e51 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -65,7 +65,6 @@ for (op, iterop) in ((:+, :*), (:*, :^), (:add_sum, :mul_prod), (:mul_prod, :^), end end - function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) val(_...) = f(getindex_value(A), getindex_value(B)) reduce(op, map(val, A, B); kw...) From 469cbed7f2e52fd33beca50963e6f3340900b34a Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Wed, 15 Nov 2023 13:59:15 +0000 Subject: [PATCH 06/10] coverage --- src/fillbroadcast.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 86975e51..9eddf602 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -32,20 +32,20 @@ end ### mapreduce -function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) +function Base._mapreduce_dim(f, op, nt, A::AbstractFill, ::Colon) fval = f(getindex_value(A)) - out = fval - for _ in 2:length(A) + out = nt + for _ in 1:length(A) out = op(out, fval) end out end -function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) +function Base._mapreduce_dim(f, op, nt, A::AbstractFill, dims) fval = f(getindex_value(A)) red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) - out = fval - for _ in 2:red + out = nt + for _ in 1:red out = op(out, fval) end Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) From f281594c661c0ed0f7cd8416efd35b9d73bab01f Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Wed, 15 Nov 2023 16:48:11 +0000 Subject: [PATCH 07/10] fix and inline --- src/fillbroadcast.jl | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 9eddf602..2d37a465 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -31,7 +31,8 @@ function _maplinear(rs...) # tries to match Base's behaviour, could perhaps hook end ### mapreduce - +@inline red(A, dims) = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) +@inline outdim(A, dims) = ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A)) function Base._mapreduce_dim(f, op, nt, A::AbstractFill, ::Colon) fval = f(getindex_value(A)) out = nt @@ -40,29 +41,36 @@ function Base._mapreduce_dim(f, op, nt, A::AbstractFill, ::Colon) end out end +function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) + fval = f(getindex_value(A)) + out = fval + for _ in 2:length(A) + out = op(out, fval) + end + out +end function Base._mapreduce_dim(f, op, nt, A::AbstractFill, dims) fval = f(getindex_value(A)) - red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) out = nt - for _ in 1:red + for _ in 1:red(A, dims) out = op(out, fval) end - Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) + Fill(out, outdim(A, dims)) +end +function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) + fval = f(getindex_value(A)) + out = fval + for _ in 2:red(A, dims) + out = op(out, fval) + end + Fill(out, outdim(A, dims)) end firstval(a, b) = a for (op, iterop) in ((:+, :*), (:*, :^), (:add_sum, :mul_prod), (:mul_prod, :^), (:max, :firstval), (:min, :firstval), (:|, :firstval), (:&, :firstval)) - @eval function Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, dims) - fval = f(getindex_value(A)) - red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) - out = ($iterop)(fval, red) - Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) - end - @eval function Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, ::Colon) - fval = f(getindex_value(A)) - ($iterop)(fval, length(A)) - end + @eval Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, dims) = Fill(($iterop)(f(getindex_value(A)), red(A, dims)), outdim(A, dims)) + @eval Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, ::Colon) = ($iterop)(f(getindex_value(A)), length(A)) end function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) From 9226c33bf6e79290280a3c86816cc7793af7de79 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Thu, 16 Nov 2023 10:27:29 +0000 Subject: [PATCH 08/10] add iterfun --- src/fillbroadcast.jl | 44 +++++++++++++------------------------------- test/runtests.jl | 2 +- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 2d37a465..58d18aa7 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -33,44 +33,26 @@ end ### mapreduce @inline red(A, dims) = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) @inline outdim(A, dims) = ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A)) -function Base._mapreduce_dim(f, op, nt, A::AbstractFill, ::Colon) - fval = f(getindex_value(A)) +@inline function iterfun(f, n, val, nt=val) out = nt - for _ in 1:length(A) - out = op(out, fval) - end - out -end -function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) - fval = f(getindex_value(A)) - out = fval - for _ in 2:length(A) - out = op(out, fval) + for _ in 1:n + out = f(out, val) end out end - -function Base._mapreduce_dim(f, op, nt, A::AbstractFill, dims) - fval = f(getindex_value(A)) - out = nt - for _ in 1:red(A, dims) - out = op(out, fval) - end - Fill(out, outdim(A, dims)) -end -function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) - fval = f(getindex_value(A)) - out = fval - for _ in 2:red(A, dims) - out = op(out, fval) - end - Fill(out, outdim(A, dims)) -end +Base._mapreduce_dim(f, op, nt, A::AbstractFill, ::Colon) = iterfun(op, length(A), f(getindex_value(A)), nt) +Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) = iterfun(op, length(A)-1, f(getindex_value(A))) +Base._mapreduce_dim(f, op, nt, A::AbstractFill, dims) = Fill(iterfun(op, red(A, dims), f(getindex_value(A)), nt), outdim(A, dims)) +Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) = Fill(iterfun(op, red(A, dims)-1, f(getindex_value(A))), outdim(A, dims)) firstval(a, b) = a for (op, iterop) in ((:+, :*), (:*, :^), (:add_sum, :mul_prod), (:mul_prod, :^), (:max, :firstval), (:min, :firstval), (:|, :firstval), (:&, :firstval)) - @eval Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, dims) = Fill(($iterop)(f(getindex_value(A)), red(A, dims)), outdim(A, dims)) - @eval Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, ::Colon) = ($iterop)(f(getindex_value(A)), length(A)) + @eval begin + Base._mapreduce_dim(f, ::typeof($op), nt, A::AbstractFill, dims) = Fill($op(nt, $iterop(f(getindex_value(A)), red(A, dims))), outdim(A, dims)) + Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, dims) = Fill(($iterop)(f(getindex_value(A)), red(A, dims)), outdim(A, dims)) + Base._mapreduce_dim(f, ::typeof($op), nt, A::AbstractFill, ::Colon) = $op(nt, ($iterop)(f(getindex_value(A)), length(A))) + Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, ::Colon) = ($iterop)(f(getindex_value(A)), length(A)) + end end function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) diff --git a/test/runtests.jl b/test/runtests.jl index 50d42c54..de0a220b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1125,7 +1125,7 @@ end @test mapreduce(identity, +, Y) == sum(y) == sum(Y) @test mapreduce(identity, +, Y, dims=1) == sum(y, dims=1) == sum(Y, dims=1) - @test mapreduce(exp, +, Y; dims=(1,), init=5.0) == mapreduce(exp, +, y; dims=(1,), init=5.0) + @test isapprox(mapreduce(exp, +, Y; dims=(1,), init=5.0), mapreduce(exp, +, y; dims=(1,), init=5.0), rtol=eps()) # Two arrays @test mapreduce(*, +, x, Y) == mapreduce(*, +, x, y) From f20f015cfe44142244b2e439b0d5bf4716c6cd62 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Thu, 16 Nov 2023 11:31:47 +0000 Subject: [PATCH 09/10] add tests --- test/runtests.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index de0a220b..4b83baa0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1120,12 +1120,17 @@ end Y = Fill(1.0, 3, 4) O = Ones(3, 4) + op2(x,y) = x^2 + 3y + @test mapreduce(exp, +, Y) == mapreduce(exp, +, y) + @test mapreduce(exp, op2, Y) == mapreduce(exp, op2, y) @test mapreduce(exp, +, Y; dims=2) == mapreduce(exp, +, y; dims=2) @test mapreduce(identity, +, Y) == sum(y) == sum(Y) @test mapreduce(identity, +, Y, dims=1) == sum(y, dims=1) == sum(Y, dims=1) @test isapprox(mapreduce(exp, +, Y; dims=(1,), init=5.0), mapreduce(exp, +, y; dims=(1,), init=5.0), rtol=eps()) + @test isapprox(mapreduce(exp, op2, Y; dims=(1,), init=5.0), mapreduce(exp, op2, y; dims=(1,), init=5.0), rtol=eps()) + @test isapprox(mapreduce(exp, op2, Y; init=5.0), mapreduce(exp, op2, y; init=5.0), rtol=eps()) # Two arrays @test mapreduce(*, +, x, Y) == mapreduce(*, +, x, y) @@ -1134,7 +1139,6 @@ end @test mapreduce(*, +, Y, O) == mapreduce(*, +, y, y) f2(x,y) = 1 + x/y - op2(x,y) = x^2 + 3y @test mapreduce(f2, op2, x, Y) == mapreduce(f2, op2, x, y) @test mapreduce(f2, op2, x, Y, dims=1, init=5.0) == mapreduce(f2, op2, x, y, dims=1, init=5.0) From e0854bbb70128f4c47a766d3d89049a40adb7a09 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Thu, 16 Nov 2023 16:13:45 +0000 Subject: [PATCH 10/10] add tests --- test/runtests.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4b83baa0..0a16482f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1123,14 +1123,15 @@ end op2(x,y) = x^2 + 3y @test mapreduce(exp, +, Y) == mapreduce(exp, +, y) - @test mapreduce(exp, op2, Y) == mapreduce(exp, op2, y) @test mapreduce(exp, +, Y; dims=2) == mapreduce(exp, +, y; dims=2) @test mapreduce(identity, +, Y) == sum(y) == sum(Y) @test mapreduce(identity, +, Y, dims=1) == sum(y, dims=1) == sum(Y, dims=1) @test isapprox(mapreduce(exp, +, Y; dims=(1,), init=5.0), mapreduce(exp, +, y; dims=(1,), init=5.0), rtol=eps()) - @test isapprox(mapreduce(exp, op2, Y; dims=(1,), init=5.0), mapreduce(exp, op2, y; dims=(1,), init=5.0), rtol=eps()) - @test isapprox(mapreduce(exp, op2, Y; init=5.0), mapreduce(exp, op2, y; init=5.0), rtol=eps()) + @test mapreduce(exp, op2, Y; dims=(1,), init=5.0) == mapreduce(exp, op2, y; dims=(1,), init=5.0) + @test mapreduce(exp, op2, Y; dims=(1,)) == [mapreduce(exp, op2, y[:,k]) for k in 1:4]' # see https://github.com/JuliaLang/julia/issues/52188 + @test mapreduce(exp, op2, Y; init=5.0) == mapreduce(exp, op2, y; init=5.0) + @test mapreduce(exp, op2, Y) == mapreduce(exp, op2, y) # Two arrays @test mapreduce(*, +, x, Y) == mapreduce(*, +, x, y)