
Capture DifferentiationInterface calls for efficient Nested AD #600

Open

avik-pal opened this issue Apr 24, 2024 · 11 comments

@avik-pal (Member)

We capture:

  1. ForwardDiff.gradient
  2. ForwardDiff.jacobian
  3. Zygote.gradient
  4. Zygote.jacobian

after #598. We should also capture the DI jacobian, gradient, and, most importantly, pullback calls and augment them with the faster versions.

An important question here is whether we should switch all calls or only calls with SecondOrder. I prefer the former, where we can just use ForwardDiff to do the AD. Maybe for SecondOrder we respect the user's choice.
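
For context, a rough sketch of the pattern this capture targets. The model, loss, and data below are placeholders, and the DI spelling of the inner call is precisely the one that is not captured yet, so treat this as illustrative of the goal rather than something that is currently fast (or even guaranteed to run):

```julia
using Lux, Random, Zygote, ADTypes
import DifferentiationInterface as DI

model = Dense(4 => 4, tanh)
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 4, 2)

function loss(ps, x)
    smodel = StatefulLuxLayer(model, ps, st)
    # Inner gradient w.r.t. the input. The ForwardDiff.gradient / Zygote.gradient
    # spellings of this call are captured after #598; the DI spelling below is
    # what this issue proposes to capture as well.
    gx = DI.gradient(z -> sum(abs2, smodel(z)), AutoZygote(), x)
    return sum(abs2, gx)
end

# Outer (reverse-mode) gradient w.r.t. the parameters: this is the nested AD call.
Zygote.gradient(loss, ps, x)
```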

@gdalle (Contributor) commented Apr 27, 2024

Commenting so I keep track of this. If you think something deserves to be in DI, let me know!
cc @adrhill

@avik-pal (Member, Author)

@gdalle what do you think about the last part in https://discourse.julialang.org/t/ann-lux-jl-explicitly-parameterized-neural-networks-in-julia/81689/65?u=avikpal?

If that exists in DI, I can just unwrap StatefulLuxLayer into that DI struct and forward the call.

@avik-pal (Member, Author) commented May 5, 2024

Seems like overloading the calls for custom functions won't really work because of ambiguity issues:

[screenshot of the resulting method ambiguity errors]

@prbzrg (Contributor) commented May 7, 2024

Can't we just specialize on Lux.StatefulLuxLayer and pass the arguments to Lux.vector_jacobian_product for DifferentiationInterface.pullback?
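
A sketch of what that might look like (hypothetical overload, not existing code; the exact positional arguments of DifferentiationInterface.pullback at the time may differ, and the vector_jacobian_product signature is taken from the Lux docs):

```julia
import DifferentiationInterface as DI
using Lux, ADTypes

# Hypothetical: route DI's pullback on a stateful Lux layer to Lux's own VJP
# helper, which already knows how to set up the efficient nested-AD path.
# (As discussed below, this runs into ambiguity problems in practice.)
function DI.pullback(f::StatefulLuxLayer, backend::ADTypes.AbstractADType, x, dy)
    return Lux.vector_jacobian_product(f, backend, x, dy)
end
```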

@avik-pal (Member, Author) commented May 7, 2024

I was trying that for the gradient calls, but DI specializes on the extras type, which means we would also have to specialize on each extras type for all backends.
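
A minimal, self-contained illustration of that ambiguity (toy types, not DI's actual method table): one method specializes on the extras argument, the other on the function argument, and neither is more specific than the other when both specialized types show up in the same call.

```julia
# Toy stand-ins: ZygoteGradientExtras plays the role of a backend-specific
# extras type inside DI, MyStatefulLayer plays the role of StatefulLuxLayer.
struct ZygoteGradientExtras end
struct MyStatefulLayer end

grad(f, x, extras) = :generic
grad(f, x, extras::ZygoteGradientExtras) = :specialized_on_extras   # "DI-style" method
grad(f::MyStatefulLayer, x, extras) = :specialized_on_function      # "Lux-style" overload

grad(MyStatefulLayer(), rand(3), ZygoteGradientExtras())
# ERROR: MethodError: grad(::MyStatefulLayer, ::Vector{Float64}, ::ZygoteGradientExtras) is ambiguous
```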

@gdalle (Contributor) commented Jun 3, 2024

To support second order for Enzyme, I introduced DifferentiationInterface.nested(::AbstractADType) in gdalle/DifferentiationInterface.jl#285. The idea is that it returns a possibly different version of the backend object, one which is aware that it is being differentiated. At the moment it doesn't do anything, except for AutoEnzyme, which is turned into a homemade AutoDeferredEnzyme.
Would this be useful functionality for Lux.jl and friends? Should I make it public / work on it some more?
One could imagine an extension where nested tells the inner backend which outer backend is trying to differentiate through it.
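
A condensed sketch of that idea (simplified from the behavior described above for gdalle/DifferentiationInterface.jl#285, not the actual DI code; AutoDeferredEnzyme here is a stand-in for DI's internal type):

```julia
using ADTypes

# Stand-in for DI's internal deferred-Enzyme backend.
struct AutoDeferredEnzyme <: ADTypes.AbstractADType end

# By default the backend is returned unchanged; AutoEnzyme is swapped for a
# variant that knows it is being differentiated through.
nested(backend::ADTypes.AbstractADType) = backend
nested(::AutoEnzyme) = AutoDeferredEnzyme()
```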

@prbzrg (Contributor) commented Jun 3, 2024

If I understand correctly, Lux handles nested AD implicitly by replacing the calls (#598) and explicitly via vector_jacobian_product and jacobian_vector_product.
@gdalle Can DifferentiationInterface.nested resolve the need for them (assuming that everyone only uses DI, not the APIs of each package)?
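
For reference, a hedged usage sketch of the explicit route (the model and data are placeholders, and the argument order follows the Lux nested-AD docs of the time, so treat the signatures as assumptions):

```julia
using Lux, Random, Zygote, ForwardDiff, ADTypes

model = Dense(3 => 3, tanh)
ps, st = Lux.setup(Random.default_rng(), model)
smodel = StatefulLuxLayer(model, ps, st)

x = randn(Float32, 3, 4)   # input
u = randn(Float32, 3, 4)   # seed / direction, same shape as the output here

vjp = Lux.vector_jacobian_product(smodel, AutoZygote(), x, u)       # VJP, reverse mode
jvp = Lux.jacobian_vector_product(smodel, AutoForwardDiff(), x, u)  # JVP, forward mode
```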

@gdalle (Contributor) commented Jun 3, 2024

I'm not sure, because there are several things one might want to do with nested backends, and depending on the situation this Lux replacement trick may not always be appropriate?

@gdalle (Contributor) commented Jun 3, 2024

Just putting it out there in case Avik is inspired. Essentially, modifying the backend is the cleanest approach I could think of for this type of problem.

@avik-pal (Member, Author) commented Jun 4, 2024

To clarify how nested AD works in Lux: it doesn't simply switch the backends; i.e., we don't take a Zygote.gradient(Zygote.gradient(...)...) call and make it ForwardDiff.gradient(Zygote.gradient(...)...). You could in principle do that, but you shouldn't (it would be computationally terrible). Instead, it changes the operation to a JVP over a gradient. Now just extend that to Jacobians, JVPs, VJPs, etc.

The only case where replacement is not ideal is ForwardDiff.gradient(ForwardDiff.gradient(...)) where the problem size is extremely small, but we don't replace that anyway.

All the other forms of Zygote over ForwardDiff or Zygote over Zygote (or any reverse mode over X-mode) have no computational benefit and will error in most cases, so it does make sense to switch.

Even doing an Enzyme.Reverse over Enzyme.Reverse will be a bad idea just because of the overhead of reverse mode [1]. Basically, for 2nd order (not general nested higher-order AD), it is almost certainly beneficial to switch the operations.

Footnotes

  1. Okay, it might be faster if the reverse mode is calling into one of the vendor-specific codes and the forward mode isn't, but that is mostly because we got lazy.
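
To make the "JVP over a gradient" rewrite concrete, here is a minimal generic sketch of the forward-over-reverse trick (not the Lux implementation): instead of differentiating the gradient computation in reverse mode again, seed the input with dual numbers carrying the direction v and read the directional derivative off the reverse-mode gradient.

```julia
using ForwardDiff, Zygote, LinearAlgebra

# Hessian-vector product H(x) * v computed as a JVP of the gradient.
function hvp(f, x, v)
    xd = ForwardDiff.Dual.(x, v)          # seed x with direction v
    gd = Zygote.gradient(f, xd)[1]        # reverse-mode gradient on dual numbers
    return ForwardDiff.partials.(gd, 1)   # extract the directional derivative
end

f(x) = sum(abs2, x) / 2 + sum(x)^3
x, v = rand(3), rand(3)

hvp(f, x, v) ≈ ForwardDiff.hessian(f, x) * v   # should hold
```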

@gdalle (Contributor) commented Jun 4, 2024

Oh right, my nested trick works because I needed to change the behavior of the inner backend, but here you change the behavior of the outer backend when a gradient is already happening inside. I honestly don't know if there is a nice way to integrate this into DI, especially because we don't handle multiple parameters at the moment.
