AD (Zygote, ForwardDiff) compatibility? #91
This is fairly easy to fix for Zygote. It just needs literal getproperty rules like https://github.com/SciML/RecursiveArrayTools.jl/blob/master/src/zygote.jl#L36-L41, which are essentially just identity.
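Below is a minimal sketch of what such a literal getproperty rule might look like for LArray. It is loosely modelled on the linked RecursiveArrayTools pattern and is not an official rule. It assumes that `ZygoteRules.literal_getproperty` is the hook Zygote calls for literal field access such as `A.b`, and that LabelledArrays' `similar` and `setproperty!` behave as they are used elsewhere in this thread. The gradient must have the shape of the whole array, so the pullback is not a strict identity. Instead it writes the incoming cotangent into an otherwise-zero LArray at the accessed label.

```julia
using LabelledArrays, Zygote, ZygoteRules

# Sketch (assumption): a Zygote adjoint for literal field access `A.s` on an LArray,
# attached to ZygoteRules.literal_getproperty, which Zygote uses for `A.s` syntax.
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::LArray, ::Val{s}) where {s}
    function literal_getproperty_LArray_adjoint(d)
        # No gradient flowing back: propagate nothing.
        d === nothing && return (nothing, nothing)
        # `A.__x` is the raw underlying array, so its cotangent is already array-shaped.
        s === :__x && return (d, nothing)
        # Otherwise scatter `d` into an all-zero LArray at label `s`.
        Δ = fill!(similar(A), zero(eltype(A)))
        setproperty!(Δ, s, d)
        # One entry per argument: the LArray, then the `Val` (no gradient).
        return (Δ, nothing)
    end
    return getproperty(A, s), literal_getproperty_LArray_adjoint
end
```

With this loaded, `Zygote.gradient(x -> x.a * x.b, @LArray [2.0, 3.0] (:a, :b))` should give a gradient whose `a` and `b` entries are `3.0` and `2.0`.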
Just to update this issue: the example given above now works with ForwardDiff, but Zygote still fails with the same error. Great to see some progress!
[f6369f11] ForwardDiff v0.10.22
[2ee39098] LabelledArrays v1.6.5
[e88e6eb3] Zygote v0.6.29
To address the particular error above, I believe the following should do it:
using ChainRulesCore, LabelledArrays

function ChainRulesCore.rrule(::typeof(getproperty), A::LArray, s::Symbol)
    function getproperty_LArray_adjoint(d)
        # NOTE: I hope this reference to `A` is optimized away.
        Δ = similar(A) .= 0
        # Cotangents may arrive as thunks; materialize before writing into Δ.
        setproperty!(Δ, s, unthunk(d))
        return (NoTangent(), Δ, NoTangent())
    end
    return getproperty(A, s), getproperty_LArray_adjoint
end

You might also run into missing adjoints for the constructor when using Zygote. These can be added with something like:
function ChainRulesCore.rrule(::Type{LArray{S}}, x::AbstractArray) where {S}
    # Sometimes we're pulling back gradients which are not `LArray`.
    constructor_LArray_adjoint(Δx::AbstractArray) = NoTangent(), Δx
    constructor_LArray_adjoint(Δlx::LArray) = NoTangent(), Δlx.__x
    return LArray{S}(x), constructor_LArray_adjoint
end

You'd also need a similar one for
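The constructor rule can be sanity-checked without any AD engine by calling `ChainRulesCore.rrule` directly and then the pullback it returns. The sketch below repeats the rule so that it runs on its own, then exercises both pullback branches. The input values are arbitrary and purely illustrative.

```julia
using ChainRulesCore, LabelledArrays

# The constructor rule from the comment above, repeated so this block is self-contained.
function ChainRulesCore.rrule(::Type{LArray{S}}, x::AbstractArray) where {S}
    # Cotangents may or may not arrive wrapped in an LArray.
    constructor_LArray_adjoint(Δx::AbstractArray) = (NoTangent(), Δx)
    constructor_LArray_adjoint(Δlx::LArray) = (NoTangent(), Δlx.__x)
    return LArray{S}(x), constructor_LArray_adjoint
end

# Call the rule directly to inspect the primal result and both pullback branches.
x = [1.0, 2.0]
lx, back = ChainRulesCore.rrule(LArray{(:a, :b)}, x)
_, Δ_plain = back([10.0, 20.0])                        # plain array passes through
_, Δ_unwrapped = back(LArray{(:a, :b)}([10.0, 20.0]))  # LArray is unwrapped via `.__x`
```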
@torfjelde Fantastic. Could you open a PR with those?
Sure can 👍
LabelledArrays looks very promising for simplifying model definitions. Unfortunately, it seems largely incompatible with AD. Is there a good workaround? (In the case of ForwardDiff, this is probably related to #68.)