Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to the latest broadcast implement. #284

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 49 additions & 35 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,33 +497,53 @@
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
using Base.Broadcast: combine_styles

struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
@static if fieldcount(Base.Broadcast.Broadcasted) == 4
struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
style::S
StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}(style)

Check warning on line 503 in src/structarray.jl

View check run for this annotation

Codecov / codecov/patch

src/structarray.jl#L503

Added line #L503 was not covered by tests
end
StructArrayStyle{N}(style::StructArrayStyle) where {N} = StructArrayStyle{N}(style.style)
parent_style(s::BroadcastStyle) = s
parent_style(s::StructArrayStyle) = s.style
style(bc::Broadcasted) = bc.style

Check warning on line 508 in src/structarray.jl

View check run for this annotation

Codecov / codecov/patch

src/structarray.jl#L505-L508

Added lines #L505 - L508 were not covered by tests
const broadcasted = Broadcasted
else
struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}()
end
StructArrayStyle{N}(style::StructArrayStyle{M, S}) where {N, M, S} = StructArrayStyle{N}(S())
parent_style(s::BroadcastStyle) = s
parent_style(::StructArrayStyle{N, S}) where {N, S} = S()
style(::Broadcasted{Style}) where {Style} = Style()
broadcasted(s, f, args, axes) = Broadcasted{typeof(s)}(f, args, axes)
end
StructArrayStyle{N, S}() where {N, S} = StructArrayStyle{N}(S())
parent_style(bc::Broadcasted) = parent_style(style(bc))
ofstyle(s, bc::Broadcasted) = broadcasted(s, bc.f, bc.args, bc.axes)

# Here we define the dimension tracking behavior of StructArrayStyle
function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
function StructArrayStyle{M, S}(::Val{N}) where {S, M, N}

Check warning on line 525 in src/structarray.jl

View check run for this annotation

Codecov / codecov/patch

src/structarray.jl#L525

Added line #L525 was not covered by tests
T = S <: AbstractArrayStyle{M} ? typeof(S(Val{N}())) : S
return StructArrayStyle{T, N}()
return StructArrayStyle{N, T}()

Check warning on line 527 in src/structarray.jl

View check run for this annotation

Codecov / codecov/patch

src/structarray.jl#L527

Added line #L527 was not covered by tests
end

# StructArrayStyle is a wrapped style.
# Here we try our best to resolve style conflict.
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{N, S}) where {S, N, M}
N′ = M === Any || N === Any ? Any : max(M, N)
S′ = Broadcast.result_style(S(), b)
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
return StructArrayStyle{N′}(Broadcast.result_style(parent_style(a), b))
end
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()

@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(BroadcastStyle(A), args...)
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
combine_style_types(s::BroadcastStyle) = s

Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)

BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{ndims(SA)}(cst(SA))

"""
always_struct_broadcast(style::BroadcastStyle)
Expand Down Expand Up @@ -551,8 +571,8 @@
"""
try_struct_copy(bc::Broadcasted) = copy(bc)

function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
if always_struct_broadcast(S())
function Base.copy(bc::Broadcasted{<:StructArrayStyle})
if always_struct_broadcast(parent_style(bc))
return invoke(copy, Tuple{Broadcasted}, bc)
else
return try_struct_copy(replace_structarray(bc))
Expand All @@ -567,55 +587,49 @@
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
"""
function replace_structarray(bc::Broadcasted{Style}) where {Style}
function replace_structarray(bc::Broadcasted)
args = replace_structarray_args(bc.args)
Style′ = parent_style(Style())
return Broadcasted{Style′}(bc.f, args, bc.axes)
style = parent_style(bc)
return broadcasted(style, bc.f, args, bc.axes)
end
function replace_structarray(A::StructArray)
f = Instantiator(eltype(A))
args = Tuple(components(A))
Style = typeof(combine_styles(args...))
return Broadcasted{Style}(f, args, axes(A))
style = combine_styles(args...)
return broadcasted(style, f, args, axes(A))
end
replace_structarray(@nospecialize(A)) = A

replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
replace_structarray_args(::Tuple{}) = ()

parent_style(@nospecialize(x)) = typeof(x)
parent_style(::StructArrayStyle{S, N}) where {S, N} = S
parent_style(::StructArrayStyle{S, N}) where {N, S<:AbstractArrayStyle{N}} = S
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle{Any}, N} = S
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle, N} = typeof(S(Val(N)))

# `instantiate` and `_axes` might be overloaded for static axes.
function Broadcast.instantiate(bc::Broadcasted{Style}) where {Style <: StructArrayStyle}
Style′ = parent_style(Style())
bc′ = Broadcast.instantiate(convert(Broadcasted{Style′}, bc))
return convert(Broadcasted{Style}, bc′)
function Broadcast.instantiate(bc::Broadcasted{<:StructArrayStyle})
bc′ = Broadcast.instantiate(ofstyle(parent_style(bc), bc))
return ofstyle(style(bc), bc′)
end

function Broadcast._axes(bc::Broadcasted{Style}, ::Nothing) where {Style <: StructArrayStyle}
Style′ = parent_style(Style())
return Broadcast._axes(convert(Broadcasted{Style′}, bc), nothing)
function Broadcast._axes(bc::Broadcasted{<:StructArrayStyle}, ::Nothing)
return Broadcast._axes(ofstyle(parent_style(bc), bc), nothing)
end

# Here we use `similar` defined for `S` to build the dest Array.
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
bc′ = convert(Broadcasted{S}, bc)
function Base.similar(bc::Broadcasted{<:StructArrayStyle}, ::Type{ElType}) where {ElType}
bc′ = ofstyle(parent_style(bc), bc)
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
end

# Unwrapper to recover the behaviour defined by parent style.
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
bc′ = always_struct_broadcast(S()) ? convert(Broadcasted{S}, bc) : replace_structarray(bc)
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle})
ps = parent_style(bc)
bc′ = always_struct_broadcast(ps) ? ofstyle(ps, bc) : replace_structarray(bc)
return copyto!(dest, bc′)
end

@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
bc′ = always_struct_broadcast(S()) ? bc : replace_structarray(bc)
return Broadcast.materialize!(S(), dest, bc′)
@inline function Broadcast.materialize!(s::StructArrayStyle, dest, bc::Broadcasted)
ps = parent_style(s)
bc′ = always_struct_broadcast(ps) ? bc : replace_structarray(bc)
return Broadcast.materialize!(ps, dest, bc′)
end

# for aliasing analysis during broadcast
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
ares = map(a->a.re, as)
aims = map(a->a.im, as)
style = Broadcast.combine_styles(ares...)
@test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}()
@test Broadcast.combine_styles(as...) === StructArrayStyle{1,typeof(style)}()
if !(style in tested_style)
push!(tested_style, style)
if style isa Broadcast.ArrayStyle{MyArray3}
Expand All @@ -1249,8 +1249,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

#parent_style
@test StructArrays.parent_style(StructArrayStyle{Broadcast.DefaultArrayStyle{0},2}()) == Broadcast.DefaultArrayStyle{2}
@test StructArrays.parent_style(StructArrayStyle{Broadcast.Style{Tuple},2}()) == Broadcast.Style{Tuple}
@test StructArrays.parent_style(StructArrayStyle{2,Broadcast.DefaultArrayStyle{0}}()) == Broadcast.DefaultArrayStyle{0}()
@test StructArrays.parent_style(StructArrayStyle{2,Broadcast.Style{Tuple}}()) == Broadcast.Style{Tuple}()

# allocation test for overloaded `broadcast_unalias`
StructArrays.always_struct_broadcast(::Broadcast.ArrayStyle{MyArray1}) = false
Expand Down
Loading