diff --git a/src/Tracing.jl b/src/Tracing.jl index b6037d7c..30a617cc 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -563,7 +563,29 @@ end @inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=()) track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ()) + return to_rarray_internal(x, track_numbers) +end + +@inline function to_rarray_internal(@nospecialize(x), track_numbers::Tuple) return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers) end -to_rarray(x::ReactantPrimitive) = ConcreteRArray(x) +function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple) + return error("Cannot convert TracedRArray to ConcreteRArray") +end +@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x +@inline function to_rarray_internal( + @nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple +) + return ConcreteRArray(x) +end + +@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), ::Tuple) = x +@inline function to_rarray_internal( + @nospecialize(x::ReactantPrimitive), track_numbers::Tuple +) + for T in track_numbers + typeof(x) <: T && return ConcreteRNumber(x) + end + return x +end diff --git a/test/tracing.jl b/test/tracing.jl index d75a4359..88b1e766 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -100,4 +100,18 @@ using Test end end end + + @testset "specialized dispatches" begin + @test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray( + 1.0; track_numbers=(Number,) + ) isa ConcreteRNumber + @test @inferred Reactant.to_rarray(1.0) isa Float64 + @test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray + + x_ra = Reactant.to_rarray(rand(3)) + @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray + + x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,)) + @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber + end end