
Externalize gradient computations to DifferentiationInterface.jl? #544

Open · gdalle opened this issue Mar 18, 2024 · 9 comments
gdalle (Contributor) commented Mar 18, 2024

Hey there Avik!

As you may know, I have been busy developing DifferentiationInterface.jl, and it's really starting to take shape.
I was wondering whether it could be useful to Lux.jl as a dependency, in order to support a wider variety of the autodiff backends defined by ADTypes.jl?

Looking at the code, it seems the main spot where AD comes up (beyond the docs and tutorials) is Lux.Training:

function compute_gradients(ad::ADTypes.AbstractADType, ::Function, _, ::TrainState)
    return __maybe_implemented_compute_gradients(ad)
end
function __maybe_implemented_compute_gradients(::T) where {T <: ADTypes.AbstractADType}
    throw(ArgumentError("Support for AD backend $(nameof(T)) has not been implemented yet!!!"))
end
function __maybe_implemented_compute_gradients(::ADTypes.AutoZygote)
    throw(ArgumentError("Load `Zygote` with `using Zygote`/`import Zygote` before using this function!"))
end
function __maybe_implemented_compute_gradients(::ADTypes.AutoTracker)
    throw(ArgumentError("Load `Tracker` with `using Tracker`/`import Tracker` before using this function!"))
end

Gradients are only implemented in the extensions for Zygote and Tracker:

function Lux.Experimental.compute_gradients(
        ::AutoZygote, objective_function::Function, data, ts::Lux.Experimental.TrainState)
    (loss, st, stats), back = Zygote.pullback(
        ps -> objective_function(ts.model, ps, ts.states, data), ts.parameters)
    grads = back((one(loss), nothing, nothing))[1]
    @set! ts.states = st
    return grads, loss, stats, ts
end

function Lux.Experimental.compute_gradients(
        ::AutoTracker, objective_function::Function, data, ts::Lux.Experimental.TrainState)
    ps_tracked = fmap(param, ts.parameters)
    loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
    back!(loss)
    @set! ts.states = st
    grads = fmap(Tracker.grad, ps_tracked)
    return grads, loss, stats, ts
end
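For concreteness, here is a rough sketch of how a single backend-generic method could replace those two extensions. It is hypothetical, not actual Lux or DI code: the name di_compute_gradients and the Ref-based side channel are illustrative assumptions. DI differentiates scalar-valued functions, so the auxiliary outputs st and stats have to be carried out of the closure some other way, and some backends may not tolerate the mutation.

using DifferentiationInterface
using Setfield: @set!   # @set! as in the extension code above (from Setfield)
import ADTypes

# Illustrative sketch only: one generic method instead of one extension per backend.
function di_compute_gradients(backend::ADTypes.AbstractADType,
        objective_function::F, data, ts) where {F}
    aux = Ref{Any}()  # stash for the auxiliary outputs of the objective
    function loss_fn(ps)
        loss, st, stats = objective_function(ts.model, ps, ts.states, data)
        aux[] = (st, stats)  # side channel; backends may need special handling here
        return loss
    end
    loss, grads = DifferentiationInterface.value_and_gradient(
        loss_fn, backend, ts.parameters)
    st, stats = aux[]
    @set! ts.states = st
    return grads, loss, stats, ts
end

Whether a single generic method like this can fully replace the Tracker-specific fmap path is part of what would need to be worked out.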

While DifferentiationInterface.jl is not yet ready or registered, it has a few niceties like Enzyme support which might pique your interest. I'm happy to discuss with you and see what other features you might need.

The main one I anticipate is compatibility with ComponentArrays.jl (gdalle/DifferentiationInterface.jl#54), and I'll try to add it soon.

cc @adrhill

@avik-pal
Copy link
Member

avik-pal commented Mar 21, 2024

Yeah, I am all for purging that code and depending on DifferentiationInterface instead. The major things that need to work:

  1. Tracker
  2. ComponentArrays (probably just testing)

That should be enough to get started (see the sketch below).
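A minimal smoke test for those two points could look something like this. The code is hypothetical and assumes DI's Tracker extension is loaded; whether the returned gradient keeps (or can be rebuilt with) the ComponentVector axes is exactly the kind of thing the tests would need to check.

using DifferentiationInterface, ComponentArrays, Tracker
import ADTypes: AutoTracker

# Flat ComponentVector of "parameters" and a simple scalar loss over it.
ps = ComponentVector(weight = randn(3), bias = zeros(3))
loss(p) = sum(abs2, p)

val, grad = DifferentiationInterface.value_and_gradient(loss, AutoTracker(), ps)
# Expect val ≈ sum(abs2, ps) and grad ≈ 2 .* ps.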

gdalle (Contributor, Author) commented Mar 21, 2024

Tracker works and is tested; ComponentArrays is my next target.

avik-pal (Member) commented

That's great. This week is a bit busy, but I can possibly try early next month.

gdalle (Contributor, Author) commented Mar 21, 2024

No rush! And always happy to help debug.
We're also chatting with the Flux developers to figure out how best to support their use case, which is slightly more complex because Flux lacks a dedicated parameter type like ComponentVector.

avik-pal (Member) commented
Do you strictly require a vector input? ComponentArrays has an overhead for smallish arrays (see #49), so having an fmap-based API might be good, though that is not really a big blocker.
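To make the two parameter layouts under discussion concrete (illustrative values only; fmap here is Functors.fmap, the same traversal the Tracker extension above uses):

using ComponentArrays, Functors

# Nested parameters the way Lux produces them (NamedTuples of arrays)...
ps_nested = (layer_1 = (weight = randn(4, 3), bias = zeros(4)),
             layer_2 = (weight = randn(2, 4), bias = zeros(2)))

# ...versus the same parameters flattened into one contiguous vector,
# which is what a vector-only AD interface would require.
ps_flat = ComponentArray(ps_nested)

# An fmap-based API would instead walk the nested structure leaf by leaf,
# avoiding the flattening (and its overhead for small arrays).
ps_doubled = fmap(x -> 2 .* x, ps_nested)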

gdalle (Contributor, Author) commented Mar 21, 2024

At the moment, yes. We're thinking about how to be more flexible in order to accommodate Flux's needs; you can track gdalle/DifferentiationInterface.jl#87 to see how it evolves.

gdalle (Contributor, Author) commented Apr 28, 2024

DI v0.3 should be out sometime next week (I've been busy with sparse Jacobians and Hessians), but I don't think I'll have much time in the near future to revamp the Lux tests. Still, I think it would make sense to offer DI at least as a high-level interface, even if it is not yet used in the package internals or tests. It might also help you figure out #605.

gdalle (Contributor, Author) commented Apr 28, 2024

Note that for DI to work in full generality with ComponentArrays, I need jonniedie/ComponentArrays.jl#254 to be fixed. Otherwise, Jacobians and Hessians will stay broken (the rest, in particular gradients, is independent of stacking).

avik-pal (Member) commented
Yes, I want to roll it out first as a high-level interface for AbstractArray inputs, rather than using it for testing. I still want to support arbitrarily structured parameters as inputs for now.
