Skip to content

Commit

Permalink
define __broadcast ourselves
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Nov 8, 2023
1 parent b4de96b commit 60a8c8c
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions ext/StructArraysStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArray

# Broadcast overload
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype
using StructArrays: isnonemptystructtype
using Base.Broadcast: Broadcasted
using Base.Broadcast: Broadcasted, _broadcast_getindex

# StaticArrayStyle has no similar defined.
# Overload `try_struct_copy` instead.
Expand Down Expand Up @@ -79,4 +79,29 @@ end
end
end

# The `__broadcast` kernal is copied from `StaticArrays.jl`.
# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
sizes = [sz.parameters[1] for sz s.parameters]

indices = CartesianIndices(newsize)
exprs = similar(indices, Expr)
for (j, current_ind) enumerate(indices)
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
exprs[j] = :(f($(exprs_vals...)))
end

return quote
Base.@_inline_meta
return tuple($(exprs...))
end
end

broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
li = LinearIndices(oldsize)
ind = _broadcast_getindex(li, newindex)
return :(a[$i][$ind])
end

end

0 comments on commit 60a8c8c

Please sign in to comment.