diff --git a/src/structarray.jl b/src/structarray.jl index ee361c39..57fc99ab 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -497,20 +497,41 @@ end import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict using Base.Broadcast: combine_styles -struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end +@static if fieldcount(Base.Broadcast.Broadcasted) == 4 + struct StructArrayStyle{N, S} <: AbstractArrayStyle{N} + style::S + StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}(style) + end + StructArrayStyle{N}(style::StructArrayStyle) where {N} = StructArrayStyle{N}(style.style) + parent_style(s::BroadcastStyle) = s + parent_style(s::StructArrayStyle) = s.style + style(bc::Broadcasted) = bc.style + const broadcasted = Broadcasted +else + struct StructArrayStyle{N, S} <: AbstractArrayStyle{N} + StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}() + end + StructArrayStyle{N}(style::StructArrayStyle{M, S}) where {N, M, S} = StructArrayStyle{N}(S()) + parent_style(s::BroadcastStyle) = s + parent_style(::StructArrayStyle{N, S}) where {N, S} = S() + style(::Broadcasted{Style}) where {Style} = Style() + broadcasted(s, f, args, axes) = Broadcasted{typeof(s)}(f, args, axes) +end +StructArrayStyle{N, S}() where {N, S} = StructArrayStyle{N}(S()) +parent_style(bc::Broadcasted) = parent_style(style(bc)) +ofstyle(s, bc::Broadcasted) = broadcasted(s, bc.f, bc.args, bc.axes) # Here we define the dimension tracking behavior of StructArrayStyle -function StructArrayStyle{S, M}(::Val{N}) where {S, M, N} +function StructArrayStyle{M, S}(::Val{N}) where {S, M, N} T = S <: AbstractArrayStyle{M} ? typeof(S(Val{N}())) : S - return StructArrayStyle{T, N}() + return StructArrayStyle{N, T}() end # StructArrayStyle is a wrapped style. # Here we try our best to resolve style conflict. -function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M} +function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{N, S}) where {S, N, M} N′ = M === Any || N === Any ? Any : max(M, N) - S′ = Broadcast.result_style(S(), b) - return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}() + return StructArrayStyle{N′}(Broadcast.result_style(parent_style(a), b)) end BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown() @@ -518,12 +539,11 @@ BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown() combine_style_types(BroadcastStyle(A), args...) @inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} = combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...) -combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle combine_style_types(s::BroadcastStyle) = s Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...) -BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}() +BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{ndims(SA)}(cst(SA)) """ always_struct_broadcast(style::BroadcastStyle) @@ -551,8 +571,8 @@ See also [`always_struct_broadcast`](@ref). """ try_struct_copy(bc::Broadcasted) = copy(bc) -function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N} - if always_struct_broadcast(S()) +function Base.copy(bc::Broadcasted{<:StructArrayStyle}) + if always_struct_broadcast(parent_style(bc)) return invoke(copy, Tuple{Broadcasted}, bc) else return try_struct_copy(replace_structarray(bc)) @@ -567,55 +587,49 @@ an equivalent one without it. This is not a must if the root `BroadcastStyle` supports `AbstractArray`. But some `BroadcastStyle` limits the input array types, e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`. """ -function replace_structarray(bc::Broadcasted{Style}) where {Style} +function replace_structarray(bc::Broadcasted) args = replace_structarray_args(bc.args) - Style′ = parent_style(Style()) - return Broadcasted{Style′}(bc.f, args, bc.axes) + style = parent_style(bc) + return broadcasted(style, bc.f, args, bc.axes) end function replace_structarray(A::StructArray) f = Instantiator(eltype(A)) args = Tuple(components(A)) - Style = typeof(combine_styles(args...)) - return Broadcasted{Style}(f, args, axes(A)) + style = combine_styles(args...) + return broadcasted(style, f, args, axes(A)) end replace_structarray(@nospecialize(A)) = A replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...) replace_structarray_args(::Tuple{}) = () -parent_style(@nospecialize(x)) = typeof(x) -parent_style(::StructArrayStyle{S, N}) where {S, N} = S -parent_style(::StructArrayStyle{S, N}) where {N, S<:AbstractArrayStyle{N}} = S -parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle{Any}, N} = S -parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle, N} = typeof(S(Val(N))) - # `instantiate` and `_axes` might be overloaded for static axes. -function Broadcast.instantiate(bc::Broadcasted{Style}) where {Style <: StructArrayStyle} - Style′ = parent_style(Style()) - bc′ = Broadcast.instantiate(convert(Broadcasted{Style′}, bc)) - return convert(Broadcasted{Style}, bc′) +function Broadcast.instantiate(bc::Broadcasted{<:StructArrayStyle}) + bc′ = Broadcast.instantiate(ofstyle(parent_style(bc), bc)) + return ofstyle(style(bc), bc′) end -function Broadcast._axes(bc::Broadcasted{Style}, ::Nothing) where {Style <: StructArrayStyle} - Style′ = parent_style(Style()) - return Broadcast._axes(convert(Broadcasted{Style′}, bc), nothing) +function Broadcast._axes(bc::Broadcasted{<:StructArrayStyle}, ::Nothing) + return Broadcast._axes(ofstyle(parent_style(bc), bc), nothing) end # Here we use `similar` defined for `S` to build the dest Array. -function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType} - bc′ = convert(Broadcasted{S}, bc) +function Base.similar(bc::Broadcasted{<:StructArrayStyle}, ::Type{ElType}) where {ElType} + bc′ = ofstyle(parent_style(bc), bc) return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType) end # Unwrapper to recover the behaviour defined by parent style. -@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N} - bc′ = always_struct_broadcast(S()) ? convert(Broadcasted{S}, bc) : replace_structarray(bc) +@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle}) + ps = parent_style(bc) + bc′ = always_struct_broadcast(ps) ? ofstyle(ps, bc) : replace_structarray(bc) return copyto!(dest, bc′) end -@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S} - bc′ = always_struct_broadcast(S()) ? bc : replace_structarray(bc) - return Broadcast.materialize!(S(), dest, bc′) +@inline function Broadcast.materialize!(s::StructArrayStyle, dest, bc::Broadcasted) + ps = parent_style(s) + bc′ = always_struct_broadcast(ps) ? bc : replace_structarray(bc) + return Broadcast.materialize!(ps, dest, bc′) end # for aliasing analysis during broadcast diff --git a/test/runtests.jl b/test/runtests.jl index 85a3637d..109c36a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1227,7 +1227,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS ares = map(a->a.re, as) aims = map(a->a.im, as) style = Broadcast.combine_styles(ares...) - @test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}() + @test Broadcast.combine_styles(as...) === StructArrayStyle{1,typeof(style)}() if !(style in tested_style) push!(tested_style, style) if style isa Broadcast.ArrayStyle{MyArray3} @@ -1249,8 +1249,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} #parent_style - @test StructArrays.parent_style(StructArrayStyle{Broadcast.DefaultArrayStyle{0},2}()) == Broadcast.DefaultArrayStyle{2} - @test StructArrays.parent_style(StructArrayStyle{Broadcast.Style{Tuple},2}()) == Broadcast.Style{Tuple} + @test StructArrays.parent_style(StructArrayStyle{2,Broadcast.DefaultArrayStyle{0}}()) == Broadcast.DefaultArrayStyle{0}() + @test StructArrays.parent_style(StructArrayStyle{2,Broadcast.Style{Tuple}}()) == Broadcast.Style{Tuple}() # allocation test for overloaded `broadcast_unalias` StructArrays.always_struct_broadcast(::Broadcast.ArrayStyle{MyArray1}) = false