Skip to content

Commit

Permalink
move current staticarray support to Ext
Browse files Browse the repository at this point in the history
Fully omit extra allocation in staticstructbroadcast.
  • Loading branch information
N5N3 committed May 30, 2023
1 parent 033ff58 commit b5ea6ea
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 83 deletions.
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,23 @@ version = "0.6.15"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[extensions]
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysStaticArraysCoreExt = "StaticArraysCore"
StructArraysStaticArraysExt = "StaticArrays"
StructArraysTablesExt = "Tables"

[compat]
Adapt = "1, 2, 3"
DataAPI = "1"
GPUArraysCore = "0.1.2"
StaticArrays = "1.5.6"
StaticArraysCore = "1.3"
Tables = "1"
julia = "1.6"

Expand All @@ -43,4 +41,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "StaticArraysCore", "Tables"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "Tables"]
77 changes: 0 additions & 77 deletions ext/StructArraysStaticArraysCoreExt.jl

This file was deleted.

82 changes: 82 additions & 0 deletions ext/StructArraysStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module StructArraysStaticArraysExt

using StructArrays
using StaticArrays: StaticArray, FieldArray, tuple_prod

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`.
```julia
julia> StructArrays.staticschema(SVector{2, Float64})
Tuple{Float64, Float64}
```
The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a
struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct
which subtypes `FieldArray`.
"""
@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
return quote
Base.@_inline_meta
return NTuple{$(tuple_prod(S)), T}
end
end
StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args)
StructArrays.component(s::StaticArray, i) = getindex(s, i)

# invoke general fallbacks for a `FieldArray` type.
@inline function StructArrays.staticschema(T::Type{<:FieldArray})
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

# Broadcast overload
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast
using StructArrays: isnonemptystructtype
using Base.Broadcast: Broadcasted

# StaticArrayStyle has no similar defined.
# Overload `try_struct_copy` instead.
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
flat = broadcast_flatten(bc); as = flat.args; f = flat.f
argsizes = broadcast_sizes(as...)
ax = axes(bc)
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.")
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
end

@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize}
first_staticarray = first_statictype(a...)
elements, ET = if prod(newsize) == 0
# Use inference to get eltype in empty case (see also comments in _map)
eltys = Tuple{map(eltype, a)...}
(), Core.Compiler.return_type(f, eltys)
else
temp = __broadcast(f, sz, s, a...)
temp, eltype(temp)
end
if isnonemptystructtype(ET)
@static if VERSION >= v"1.7"
arrs = ntuple(Val(fieldcount(ET))) do i
@inbounds similar_type(first_staticarray, fieldtype(ET, i), sz)(_getfields(elements, i))
end
else
similarET(::Type{SA}, ::Type{T}) where {SA, T} = i -> @inbounds similar_type(SA, fieldtype(T, i), sz)(_getfields(elements, i))
arrs = ntuple(similarET(first_staticarray, ET), Val(fieldcount(ET)))
end
return StructArray{ET}(arrs)
end
@inbounds return similar_type(first_staticarray, ET, sz)(elements)
end

@inline function _getfields(x::Tuple, i::Int)
if @generated
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
else
return map(Base.Fix2(getfield, i), x)
end
end

end
2 changes: 1 addition & 1 deletion src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ See also [`always_struct_broadcast`](@ref).
"""
try_struct_copy(bc::Broadcasted) = copy(bc)

function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
@inline function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
if always_struct_broadcast(S())
return invoke(copy, Tuple{Broadcasted}, bc)
else
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1297,8 +1297,10 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS

@testset "allocation test" begin
a = StructArray{ComplexF64}(undef, 1)
sa = StructArray{ComplexF64}((SizedVector{1}(a.re), SizedVector{1}(a.re)))
allocated(a) = @allocated a .+ 1
@test allocated(a) == 2allocated(a.re)
@test allocated(sa) == 2allocated(sa.re)
allocated2(a) = @allocated a .= complex.(a.im, a.re)
@test allocated2(a) == 0
end
Expand Down

0 comments on commit b5ea6ea

Please sign in to comment.