2nd order AD fails #298

Open
avik-pal opened this issue Nov 20, 2024 · 4 comments

@avik-pal
Collaborator

using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device()
const cdev = cpu_device()

model = Dense(5 => 5, gelu);
ps, st = Lux.setup(Random.default_rng(), model) |> xdev;
potential = StatefulLuxLayer{true}(model, ps, st)

# Currently EnzymeMLIR doesn't support batching so we force chunksize to 1
function ∇potential(potential, x)
    J = reshape(only(Enzyme.jacobian(Forward, potential, x; chunk=Val(1))), :, length(x))
    J_diag = @allowscalar diag(J)
    return reshape(J_diag, size(x))
end

function ∇²potential(potential, x)
    return reshape(only(
        Enzyme.jacobian(Forward, Base.Fix1(∇potential, potential), x; chunk=Val(1))
    ), :, length(x))
end

x_ra = randn(Float32, 5, 3) |> xdev

@code_hlo ∇²potential(potential, x_ra)

A non-minimal example taken from LuxDL/Lux.jl#614

@avik-pal
Collaborator Author

Error message:

ERROR: AssertionError: Base.isconcretetype(typ)
Stacktrace:
  [1] abs_typeof(arg::LLVM.LoadInst, partial::Bool, seenphis::Set{LLVM.PHIInst})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/absint.jl:557
  [2] abs_typeof(arg::LLVM.LoadInst)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/absint.jl:283
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7066
  [4] codegen
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:6146 [inlined]
  [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468
  [6] _thunk
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468 [inlined]
  [7] cached_compilation
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8509 [inlined]
  [8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8641
  [9] #s2105#19135
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8778 [inlined]
 [10] 
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:707
 [12] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:633 [inlined]
 [13] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:512 [inlined]
 [14] macro expansion
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2090 [inlined]
 [15] gradient(::ForwardMode{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…}; chunk::Val{…}, shadows::Tuple{…})
    @ Enzyme /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:1970
 [16] #jacobian#133
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2177 [inlined]
 [17] jacobian
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2176 [inlined]
 [18] ∇potential(potential::StatefulLuxLayer{…}, x::Reactant.TracedRArray{…})
    @ Main /mnt/software/lux/Lux.jl/docs/src/manual/nested_autodiff_reactant.md:17
 [19] Fix1
    @ ./operators.jl:1127 [inlined]
 [20] #apply#24
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:37 [inlined]
 [21] apply
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined]
 [22] (::Tuple{})(none::Base.Fix1{typeof(∇potential), StatefulLuxLayer{…}}, none::Tuple{Reactant.TracedRArray{…}})
    @ Base.Experimental ./<missing>:0
 [23] (::Reactant.var"#32#42"{Bool, Bool, typeof(Reactant.apply), Tuple{}, Vector{}, Tuple{}})()
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:148
 [24] block!(f::Reactant.var"#32#42"{}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [25] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:120
 [26] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined]
 [27] #make_mlir_fn#25
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:53 [inlined]
 [28] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined]
 [29] overload_autodiff(::ForwardMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…})
    @ Reactant /mnt/software/lux/Reactant.jl/src/Interpreter.jl:373
 [30] autodiff
    @ /mnt/software/lux/Reactant.jl/src/Interpreter.jl:660 [inlined]
 [31] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:512 [inlined]
 [32] macro expansion
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2090 [inlined]
 [33] gradient(::ForwardMode{…}, ::Base.Fix1{…}, ::Reactant.TracedRArray{…}; chunk::Val{…}, shadows::Tuple{…})
    @ Enzyme /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:1970
 [34] #jacobian#133
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2177 [inlined]
 [35] jacobian
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2176 [inlined]
 [36] ∇²potential
    @ /mnt/software/lux/Lux.jl/docs/src/manual/nested_autodiff_reactant.md:23 [inlined]
 [37] (::Tuple{})(none::StatefulLuxLayer{…}, none::Reactant.TracedRArray{…})
    @ Base.Experimental ./<missing>:0
 [38] (::Reactant.var"#32#42"{Bool, Bool, typeof(∇²potential), Tuple{}, Vector{}, Tuple{}})()
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:157
 [39] block!(f::Reactant.var"#32#42"{}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [40] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:120
 [41] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined]
 [42] #10
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:286 [inlined]
 [43] block!(f::Reactant.Compiler.var"#10#15"{typeof(∇²potential), Tuple{}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [44] #9
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:285 [inlined]
 [45] mmodule!(f::Reactant.Compiler.var"#9#14"{}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:92
 [46] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:282
 [47] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:281 [inlined]
 [48] #6
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:276 [inlined]
 [49] context!(f::Reactant.Compiler.var"#6#7"{@Kwargs{}, typeof(∇²potential), Tuple{}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
 [50] compile_mlir(f::Function, args::Tuple{StatefulLuxLayer{…}, ConcreteRArray{…}}; kwargs::@Kwargs{optimize::Bool})
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:274
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Member

wsmoses commented Nov 21, 2024

Just for fun, what happens if you use set_abi(Forward, ReactantABI) instead of Forward?

@avik-pal
Collaborator Author

That did work!
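
For reference, a minimal sketch of what the suggested change could look like when applied to the reproducer above. The thread does not say which call was changed, so using the modified mode in both the inner and the outer `Enzyme.jacobian` call is an assumption. The qualified names `Enzyme.set_abi` and `Reactant.ReactantABI`, and the helper names `fwd_reactant`, `∇potential_abi`, and `∇²potential_abi`, are likewise assumptions introduced for illustration.

```julia
using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device()

model = Dense(5 => 5, gelu)
ps, st = Lux.setup(Random.default_rng(), model) |> xdev
potential = StatefulLuxLayer{true}(model, ps, st)

# Assumption: set_abi and ReactantABI are reachable under these qualified names.
# This is the mode suggested in the thread as a replacement for plain `Forward`.
const fwd_reactant = Enzyme.set_abi(Forward, Reactant.ReactantABI)

# Same first-order function as in the reproducer, but with the modified mode.
function ∇potential_abi(potential, x)
    J = reshape(
        only(Enzyme.jacobian(fwd_reactant, potential, x; chunk=Val(1))), :, length(x)
    )
    J_diag = @allowscalar diag(J)
    return reshape(J_diag, size(x))
end

# Second-order function: the nested call also uses the modified mode (assumption).
function ∇²potential_abi(potential, x)
    return reshape(
        only(
            Enzyme.jacobian(
                fwd_reactant, Base.Fix1(∇potential_abi, potential), x; chunk=Val(1)
            ),
        ),
        :,
        length(x),
    )
end

x_ra = randn(Float32, 5, 3) |> xdev

@code_hlo ∇²potential_abi(potential, x_ra)
```

Whether this still behaves the same on current Enzyme and Reactant releases is not something the thread confirms; it only reports that the change worked at the time.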

@wsmoses
Member

wsmoses commented Nov 22, 2024

Yeah, so this again stems from "any abstract-interpreter-based shenanigans fail to go through type-unstable code".

Here, the actual resolution we implemented earlier is to have our absint (abstract interpreter) replace Forward with set_abi(Forward, ReactantABI). This makes things much nicer (including doing the replacement at the call site of autodiff/jacobian/etc.), so any type-unstable intermediates don't cause issues. Similarly, it means we can call it natively, as above. Unfortunately, this only applies at the top-level absint.

Probably the solution here is to have the absint replace type-unstable calls with my_call(...), which itself runs things again in an absint.
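
As a generic illustration of what "type unstable" and the `Base.isconcretetype` assertion in the error refer to, here is a small sketch that uses only Base Julia. The functions `unstable` and `stable` are hypothetical and unrelated to the Enzyme or Reactant internals; the snippet shows the general notion, not the actual failing code path in `absint.jl`.

```julia
# Hypothetical function: the return type depends on a runtime value,
# so inference cannot pin it down to a single concrete type.
unstable(x) = x > 0 ? 1.0f0 : "negative"

# Type-stable counterpart: every branch returns a Float32.
stable(x) = x > 0 ? 1.0f0 : 0.0f0

inferred_unstable = only(Base.return_types(unstable, (Int,)))
inferred_stable = only(Base.return_types(stable, (Int,)))

println("unstable: ", inferred_unstable, " concrete? ", isconcretetype(inferred_unstable))
println("stable:   ", inferred_stable, " concrete? ", isconcretetype(inferred_stable))

# Abstract types are not concrete either, which is the property the assertion checks.
println(isconcretetype(Vector{Float32}))          # concrete container type
println(isconcretetype(AbstractVector{Float32}))  # abstract supertype
```

Code that relies on every intermediate having a concrete type, as the assertion at the top of the stack trace does, has nothing to work with once inference yields a non-concrete type like the one returned for `unstable`.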
