AD (Zygote, ForwardDiff) compatibility? #91

Open
scheidan opened this issue Mar 19, 2021 · 5 comments

@scheidan

LabelledArrays looks very promising for simplifying model definitions. Unfortunately, it seems largely incompatible with AD. Is there a good workaround?

using LabelledArrays
import ForwardDiff
import Zygote

model(p) =  p.a + p.b^2 + p.c^3

p = LVector(a=1, b=2, c=3)
ps = SLVector(a=1, b=2, c=3)
model(ps)
model(p)

ForwardDiff.gradient(model, p) # works :)
Zygote.gradient(model, p)      # ERROR: ArgumentError: invalid index: Val{:c}() of type Val{:c}

ForwardDiff.gradient(model, ps) # ERROR: type SArray has no field a
Zygote.gradient(model, ps)      # ERROR: ArgumentError: invalid index: Val{:c}() of type Val{:c}

(In case of ForwardDiff this is probably related to #68)
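
For reference, the gradient that a working AD call should return for this model can be checked without any AD package. The following is only a sketch: it uses a plain NamedTuple instead of an LVector, and `analytic_gradient` and `fd_gradient` are hypothetical helper names introduced here, not part of LabelledArrays, ForwardDiff, or Zygote.

```julia
# Plain-Julia check, no AD packages: the model from the issue on a NamedTuple.
model(p) = p.a + p.b^2 + p.c^3

# Hand-derived gradient: d/da = 1, d/db = 2b, d/dc = 3c^2.
# (Hypothetical helper, for illustration only.)
analytic_gradient(p) = (a = one(p.a), b = 2 * p.b, c = 3 * p.c^2)

# Central finite differences as an independent cross-check.
# (Hypothetical helper, for illustration only.)
function fd_gradient(f, p::NamedTuple; h = 1e-6)
    map(keys(p)) do k
        hi = merge(p, NamedTuple{(k,)}((p[k] + h,)))  # perturb field k upward
        lo = merge(p, NamedTuple{(k,)}((p[k] - h,)))  # perturb field k downward
        (f(hi) - f(lo)) / (2h)
    end
end

p = (a = 1.0, b = 2.0, c = 3.0)
g = analytic_gradient(p)       # exact gradient at (1, 2, 3): (1, 4, 27)
g_fd = fd_gradient(model, p)   # numerical estimate, should agree closely
```

Any of the four gradient calls above that succeeds should agree with these values at a = 1, b = 2, c = 3.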

@ChrisRackauckas
Member

Zygote is rather easy to fix for this. It just needs literal getproperty rules like

https://github.com/SciML/RecursiveArrayTools.jl/blob/master/src/zygote.jl#L36-L41

which are actually just identity.

@scheidan
Author

scheidan commented Nov 8, 2021

Just to update this issue: the example given above now works for ForwardDiff, but Zygote still fails with the same error.

Great to see some progress!

  [f6369f11] ForwardDiff v0.10.22
  [2ee39098] LabelledArrays v1.6.5
  [e88e6eb3] Zygote v0.6.29

@torfjelde
Contributor

To address the particular error from above, I believe the following should do it:

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(getproperty), A::LArray, s::Symbol)
    function getproperty_LArray_adjoint(d)
        # NOTE: I hope this reference to `A` is optimized away.
        Δ = similar(A) .= 0
        setproperty!(Δ, s, d)
        return (NoTangent(), Δ, NoTangent())
    end
    return getproperty(A, s), getproperty_LArray_adjoint
end
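
The pullback in this rule returns a cotangent that is zero everywhere except in the slot named by `s`, which receives the incoming `d`. A minimal plain-Julia sketch of that idea follows, assuming an ordinary vector plus a tuple of labels in place of an `LArray`; `onehot_pullback` is a hypothetical helper and not part of LabelledArrays or ChainRulesCore.

```julia
# Hypothetical helper: mimics the value and pullback of `A.s` for a plain
# vector `x` whose entries are named by `labels`.
function onehot_pullback(labels::NTuple{N,Symbol}, x::AbstractVector, s::Symbol) where {N}
    i = findfirst(==(s), labels)
    i === nothing && throw(ArgumentError("unknown label $s"))
    function pullback(d)
        Δ = zero(x)   # zeros with the same shape and element type as x
        Δ[i] = d      # only the accessed slot receives the cotangent
        return Δ
    end
    return x[i], pullback
end

labels = (:a, :b, :c)
x = [1.0, 2.0, 3.0]
y, back = onehot_pullback(labels, x, :b)   # y is the value stored under :b
Δ = back(5.0)                              # cotangent with 5.0 in the :b slot
```

Cotangents from several field accesses add up, so a function that reads `a`, `b`, and `c` gets a dense gradient from three such one-hot contributions.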

You might also run into issues with missing adjoints for the constructor when using Zygote with something like @LArray ..., in which case the following should fix it:

function ChainRulesCore.rrule(::Type{LArray{S}}, x::AbstractArray) where {S}
    # Sometimes we're pulling back gradients which are not `LArray`.
    constructor_LArray_adjoint(Δx::AbstractArray) = NoTangent(), Δx
    constructor_LArray_adjoint(Δlx::LArray) = NoTangent(), Δlx.__x
    return LArray{S}(x), constructor_LArray_adjoint
end

You'd also need a similar one for SLArray.

@ChrisRackauckas
Member

@torfjelde fantastic. Could you open a PR with those?

@torfjelde
Contributor

Sure can 👍
