Skip to content

Commit

Permalink
feat: specialize dispatches for faster concrete array generation (#213)
Browse files Browse the repository at this point in the history
* feat: specialize dispatches for faster concrete array generation

* chore: apply formatting suggestion

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
avik-pal and github-actions[bot] authored Nov 1, 2024
1 parent b6ee968 commit a17315c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,29 @@ end

@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=())
track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ())
return to_rarray_internal(x, track_numbers)
end

@inline function to_rarray_internal(@nospecialize(x), track_numbers::Tuple)
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers)
end

to_rarray(x::ReactantPrimitive) = ConcreteRArray(x)
function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple)
return error("Cannot convert TracedRArray to ConcreteRArray")
end
@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x
@inline function to_rarray_internal(
@nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple
)
return ConcreteRArray(x)
end

@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), ::Tuple) = x
@inline function to_rarray_internal(
@nospecialize(x::ReactantPrimitive), track_numbers::Tuple
)
for T in track_numbers
typeof(x) <: T && return ConcreteRNumber(x)
end
return x
end
14 changes: 14 additions & 0 deletions test/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,18 @@ using Test
end
end
end

@testset "specialized dispatches" begin
@test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray(
1.0; track_numbers=(Number,)
) isa ConcreteRNumber
@test @inferred Reactant.to_rarray(1.0) isa Float64
@test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray

x_ra = Reactant.to_rarray(rand(3))
@test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray

x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,))
@test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber
end
end

1 comment on commit a17315c

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: a17315c Previous: b6ee968 Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1263749401 ns 1322156738 ns 0.96
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1254668396 ns 1293942538 ns 0.97
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1218277318 ns 1224868312 ns 0.99
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2376495016 ns 2323944334 ns 1.02
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 217726580 ns 216612531 ns 1.01
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 7226166416 ns 6954798003 ns 1.04
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5511150207 ns 5103509804 ns 1.08
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5102020848 ns 5081171584 ns 1.00
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 6993217459 ns 6720851214 ns 1.04
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 38085761917 ns 36264215655 ns 1.05
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1208392095 ns 1325295976 ns 0.91
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1331979590 ns 1316239703 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1228565001 ns 1223956642 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2452231772 ns 2499287891 ns 0.98
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8748209 ns 8665141 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1578057500 ns 1575352408 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1557311922 ns 1567227136 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1557684126 ns 1566092027.5 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 2769517816 ns 2878841123 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 3303048898.5 ns 2685299362 ns 1.23
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1303432996 ns 1239206197.5 ns 1.05
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1292627349.5 ns 1289136308 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1312140581.5 ns 1237433180 ns 1.06
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2608146101 ns 2746675598 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 22645472 ns 22719307 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2183323759 ns 2131005675 ns 1.02
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2161824787 ns 2126128561 ns 1.02
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2150773246 ns 2131061285 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3353554606 ns 3402150262 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 6032060527 ns 5740208504 ns 1.05
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1315388210 ns 1262392264.5 ns 1.04
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1313576758.5 ns 1258413265.5 ns 1.04
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1308732662.5 ns 1270552917.5 ns 1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2435356858 ns 2586048537 ns 0.94
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 6572926 ns 7031315 ns 0.93
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1416310529 ns 1421898963 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1409069455 ns 1430099101 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1410196431 ns 1422752680 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 2620146990 ns 2655576241 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1384443752 ns 1274970277 ns 1.09
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1325713657.5 ns 1274806307.5 ns 1.04
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1268777827.5 ns 1310390497 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1294207842.5 ns 1302121842 ns 0.99
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2374603722 ns 2624706431 ns 0.90
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 12110782.5 ns 12297131 ns 0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1711411728 ns 1734648342 ns 0.99
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1707811998 ns 1716005516 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1709512803 ns 1705670596 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 2924854567 ns 2930317875 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 2927891069 ns 3071789485.5 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1270178508 ns 1351363548 ns 0.94
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1317660758 ns 1300008762 ns 1.01
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1263311709 ns 1285804476 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2584191843 ns 2521378317 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 27307540.5 ns 27302342 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2190938487 ns 2243314517 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2166284687 ns 2209743795 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2137987987 ns 2196379717 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3415666738 ns 3417637678 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 6038343271.5 ns 5737977502 ns 1.05
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1233854317 ns 1239628004 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1299829181.5 ns 1471450378 ns 0.88
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1226243251 ns 1188148551.5 ns 1.03
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2393640923 ns 2290586722 ns 1.04
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 52646968 ns 52692914.5 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 3006477320 ns 2982530184 ns 1.01
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 2989128551 ns 2990386476 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 3003357676 ns 3011396498 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4443262702 ns 4338309706 ns 1.02
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 24545735518 ns 11645205146 ns 2.11
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1288108103 ns 1216846465 ns 1.06
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1247053980 ns 1268232903.5 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1260403416 ns 1322945944 ns 0.95
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2513765600 ns 2578098657 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 70692019 ns 70862545 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3164689242 ns 3193348347 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3166667974 ns 3203590115 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3168332239 ns 3154476044 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 4510953172 ns 4523619517 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 12354970629 ns 9115055641 ns 1.36
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1242550154 ns 1289173742 ns 0.96
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1270011702 ns 1268715834.5 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1308184956.5 ns 1269718689.5 ns 1.03
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2564144412 ns 2796130400 ns 0.92
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 20737061 ns 20728567 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1846241603 ns 1845983114 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1845891211 ns 1840843333 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1838778303 ns 1844902802 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3067201183 ns 3070132545 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3142722042.5 ns 3473525524.5 ns 0.90

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.