Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize on diagonal fieldvector broadcasts #1615

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,93 @@ end
@inline transform_broadcasted(fv::FieldVector, symb, axes) =
parent(getfield(_values(fv), symb))
@inline transform_broadcasted(x, symb, axes) = x

@inline function first_fieldvector_in_bc(args::Tuple, rargs...)
x1 = first_fieldvector_in_bc(args[1], rargs...)
x1 isa FieldVector && return x1
return first_fieldvector_in_bc(Base.tail(args), rargs...)
end

@inline first_fieldvector_in_bc(args::Tuple{Any}, rargs...) =
first_fieldvector_in_bc(args[1], rargs...)
@inline first_fieldvector_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_fieldvector_in_bc(x) = nothing
@inline first_fieldvector_in_bc(x::FieldVector) = x

@inline first_fieldvector_in_bc(
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
) = first_fieldvector_in_bc(bc.args)

@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple,
rargs...,
) where {TStart} =
truesofar &&
_is_diagonal_bc(truesofar, TStart, args[1], rargs...) &&
_is_diagonal_bc_args(truesofar, TStart, Base.tail(args), rargs...)

@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple{Any},
rargs...,
) where {TStart} =
truesofar && _is_diagonal_bc(truesofar, TStart, args[1], rargs...)
@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple{},
rargs...,
) where {TStart} = truesofar

@inline function _is_diagonal_bc(
truesofar,
::Type{TStart},
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
) where {TStart}
return truesofar && _is_diagonal_bc_args(truesofar, TStart, bc.args)
end

@inline _is_diagonal_bc(
truesofar,
::Type{TStart},
::TStart,
) where {TStart <: FieldVector} = true
@inline _is_diagonal_bc(
truesofar,
::Type{TStart},
x::FieldVector,
) where {TStart} = false
@inline _is_diagonal_bc(truesofar, ::Type{TStart}, x) where {TStart} = truesofar

# Find the first fieldvector in the broadcast expression (BCE),
# and compare against every other fieldvector in the BCE
@inline is_diagonal_bc(bc::Base.Broadcast.Broadcasted{FieldVectorStyle}) =
_is_diagonal_bc_args(true, typeof(first_fieldvector_in_bc(bc)), bc.args)

# Specialize on FieldVectorStyle to avoid inference failure
# in fieldvector broadcast expressions:
# https://github.com/JuliaArrays/BlockArrays.jl/issues/310
function Base.Broadcast.instantiate(
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
)
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
axes = Base.Broadcast.combine_axes(bc.args...)
else
axes = bc.axes
# Base.Broadcast.check_broadcast_axes is type-unstable
# for broadcast expressions with multiple fieldvectors.
# So, let's statically elide this when we have "diagonal"
# broadcast expressions:
if !is_diagonal_bc(bc)
Base.Broadcast.check_broadcast_axes(axes, bc.args...)
charleskawczynski marked this conversation as resolved.
Show resolved Hide resolved
end
end
return Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, axes)
end

@inline function Base.copyto!(
dest::FieldVector,
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
Expand Down
42 changes: 42 additions & 0 deletions test/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,48 @@ end
@test Y.k.z === 3.0
end

# https://github.com/CliMA/ClimaCore.jl/issues/1465
@testset "Diagonal FieldVector broadcast expressions" begin
FT = Float64
device = ClimaComms.device()
comms_ctx = ClimaComms.context(device)
cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT; context = comms_ctx)
fspace = TU.FaceExtrudedFiniteDifferenceSpace(FT; context = comms_ctx)
cx = Fields.fill((; a = FT(1), b = FT(2)), cspace)
cy = Fields.fill((; a = FT(1), b = FT(2)), cspace)
fx = Fields.fill((; a = FT(1), b = FT(2)), fspace)
fy = Fields.fill((; a = FT(1), b = FT(2)), fspace)
Y1 = Fields.FieldVector(; x = cx, y = cy)
Y2 = Fields.FieldVector(; x = cx, y = cy)
Y3 = Fields.FieldVector(; x = cx, y = cy)
Y4 = Fields.FieldVector(; x = cx, y = cy)
Z = Fields.FieldVector(; x = fx, y = fy)
function test_fv_allocations!(X1, X2, X3, X4)
@. X1 += X2 * X3 + X4
return nothing
end
test_fv_allocations!(Y1, Y2, Y3, Y4)
p_allocated = @allocated test_fv_allocations!(Y1, Y2, Y3, Y4)
if device isa ClimaComms.AbstractCPUDevice
@test p_allocated == 0
elseif device isa ClimaComms.CUDADevice
@test_broken p_allocated == 0
end

bc1 = Base.broadcasted(
:-,
Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y2)),
Base.broadcasted(:*, 3, Y3),
)
bc2 = Base.broadcasted(
:-,
Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y1)),
Base.broadcasted(:*, 3, Z),
)
@test Fields.is_diagonal_bc(bc1)
@test !Fields.is_diagonal_bc(bc2)
end

function call_getcolumn(fv, colidx)
@allowscalar fvcol = fv[colidx]
nothing
Expand Down
Loading