
Returning Broadcasted cotangents for Broadcasted arguments? #698

Open
ToucheSir opened this issue Mar 13, 2023 · 7 comments
@ToucheSir (Contributor)

More of a question on whether this makes sense than a feature request. I could see it replacing or complementing @thunk(<eager broadcast> of unbroadcast(...)) for rules such as https://github.com/JuliaDiff/ChainRules.jl/blob/v1.48.0/src/rulesets/Base/broadcast.jl#L174.

@oxinabox (Member)

Tangents need to be represented with types that support + with each other and with the primal, scalar multiplication, and zero; ideally a bunch of other linear operations that make sense for them as well.
Broadcasted does not, so we can't.
Operations that assume they are getting an array (and that special-case thunks by unthunking them) will break.
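For concreteness, a minimal sketch (Base only, hypothetical example) of what's missing on the lazy object: it can be materialised, but none of the tangent arithmetic is defined on it.

```julia
using Base.Broadcast: broadcasted, materialize

# A lazy Broadcasted; nothing is computed until materialize is called.
bc = broadcasted(*, [1.0, 2.0], 3.0)
@assert materialize(bc) == [3.0, 6.0]

# But the operations a tangent type must support are not defined on it:
@assert !hasmethod(+, Tuple{typeof(bc), Vector{Float64}})
```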

Potentially we could wrap it in a type and then do that.

Maybe in ChainRules 2.0 we should relax that requirement and export our own exp_map, accum, dzero, etc. that just fall back to +, +, and zero respectively.
But I don't think that would help when it comes to something assuming you can do some other linear-algebra operation on it.

@mcabbott (Member)

Broadcasted is a lazy object not unlike Thunk; we could add a BroadcastThunk <: AbstractThunk which wraps it and allows + etc., and on which unthunk materialises. Many present InplaceThunks could instead be this.
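A rough standalone sketch of that idea (hypothetical, not the actual ChainRules implementation; a real version would subtype ChainRulesCore.AbstractThunk and overload its unthunk):

```julia
using Base.Broadcast: Broadcasted, broadcasted, materialize

# Hypothetical wrapper: holds the lazy computation, materialised only on demand.
struct BroadcastThunk{B<:Broadcasted}
    bc::B
end

unthunk(t::BroadcastThunk) = materialize(t.bc)

# `+` stays fused: nest the lazy object into a new Broadcasted and materialise once.
Base.:+(t::BroadcastThunk, x::AbstractArray) = materialize(broadcasted(+, t.bc, x))
Base.:+(x::AbstractArray, t::BroadcastThunk) = t + x
```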

What seems tricky is to efficiently consume any variety of thunks in downstream rules. If you're going to do dΩ .* ... then you do want broadcast fusion if you get dΩ::BroadcastThunk, but if you get an InplaceThunk which is in fact mul! then perhaps you want another path. And perhaps worse, should something like the rule for ./ assume that this dΩ::BroadcastThunk contains expensive operations (materialise once) or cheap ones (fuse into two downstream broadcasts / broadcast-reductions)?

(I agree that ditching the use of + in 2.0 for an explicit accum would make mutation etc. much easier. It would also mean we could scrap Tangent & just use NamedTuples by default; maybe only Zero and cases like Tangent{Dict} need be custom types.)

@ToucheSir (Contributor, Author)

I think dedicated structural tangent types are still useful for cases like storing (co)tangents of mutable structs for rules which use setproperty!. That said, it would be nice to rely on natural tangents for struct types more where possible. I wonder if there's a way to approximate differentiable Swift's Differentiable + TangentVector interface for that.

@oxinabox (Member) commented Mar 15, 2023

With the improvements in compiler analysis coming, it's likely we will be able to do away with Thunks before ChainRules 2.0.
Diffractor should be getting a component that works out what is used and only ADs those parts, and likewise removes unrelated code from rules before that.
It's a ways off yet, but I suspect it is going to have to be finished before I have time to work on ChainRules 2.0 anyway.
Though potentially only the forward-mode bit might be done.

@mcabbott (Member) commented Mar 20, 2023

> we will be able to do away with Thunks before ChainRules 2.0.

But such analysis would not remove the above desire for BroadcastThunk -- where the goal is to fuse the reverse pass & save memory. Compiler improvements to fuse broadcasting (in code which looks like a function returning an Array, used in another broadcast) are presumably further off.

It would be helpful to have some small examples where you might expect this to matter the most. Perhaps M = (v .- v')./2 is one -- fused forwards at present, but the gradient will materialise 1 matrix for the division, and another for the subtraction. To remove these, the Broadcasted would have to be digested by unbroadcast directly.
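For the M = (v .- v')./2 example, the saving would come from reducing the lazy Broadcasted directly, without ever allocating the N×N matrix. A sketch using only Base.Broadcast (and sum(::Broadcasted; dims, init), mentioned later in this thread):

```julia
using Base.Broadcast: broadcasted, instantiate

v = rand(100)
# Lazy stand-in for the matrix cotangent of (v .- v') ./ 2:
dM = broadcasted(/, broadcasted(-, v, v'), 2)

# Reduce over dims without materialising the 100×100 matrix.
# The dims form of sum on a Broadcasted needs an explicit init:
dv = sum(instantiate(dM); dims = 2, init = 0.0)
@assert dv ≈ sum((v .- v') ./ 2; dims = 2)
```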

Edit: I have a messy prototype on a branch. One example is this:

julia> let x = rand(1000)
         @btime gradient(x -> sum((1 .- x) ./ 2), $x)  # Diffractor
         @btime copy($x)  # to compare
       end;
  min 2.838 μs, mean 7.916 μs (27 allocations, 32.83 KiB)  # before, 4 copies
  min 294.147 ns, mean 1.515 μs (1 allocation, 7.94 KiB)

  min 3.062 μs, mean 5.657 μs (28 allocations, 17.09 KiB)  # after, 2 copies

Zygote gets min 2.185 μs, mean 4.726 μs (20 allocations, 16.59 KiB) here, also 2 copies. Its gradient for sum uses Fill (disabled on GPU) which propagates through to be returned, so both copies are on the un-fused forward pass.

@ToucheSir (Contributor, Author)

Maybe unbroadcast's behaviour should also be part of this BroadcastThunk? Then it becomes almost a dual to Broadcasted. The trick would be knowing when to materialize vs keep constructing the lazy unbroadcast computation. Assuming it's easy to ascertain how many inputs a Broadcasted covers, could the heuristic be to materialize for >=2 inputs and be lazy for 1 input?
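One way the ">= 2 inputs" part could be ascertained: count the array arguments of a (possibly nested) Broadcasted. The helper names here are hypothetical, just to illustrate the heuristic:

```julia
using Base.Broadcast: Broadcasted, broadcasted

# Hypothetical helpers: count array inputs covered by a nested Broadcasted tree.
count_inputs(bc::Broadcasted) = sum(count_inputs, bc.args; init = 0)
count_inputs(x::AbstractArray) = 1
count_inputs(x) = 0                     # scalars, Refs, etc.

# Proposed heuristic: materialise for >= 2 inputs, stay lazy for 1.
should_materialize(bc::Broadcasted) = count_inputs(bc) >= 2
```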

@mcabbott (Member)

I pushed the branches now.

In this version, when unbroadcast has to perform a sum, it does so eagerly, but without materialising the Broadcasted. So (x .- x')./2 will allocate vectors, 2N, but avoid the matrix, N^2. When it does not have to sum, it passes the BroadcastThunk along, unless ProjectTo wants to reshape it:

https://github.com/JuliaDiff/ChainRules.jl/pull/705/files#diff-730f3517d23f24df68145516aec73a26702d9464ef1f2402a35d69fef75c85d2R338

I had forgotten, but sum(::Broadcasted; dims, init) is now fast, whereas sum(::Broadcasted) is slower than sum(::Array) as it uses Cartesian indexing. The dims one needs init, which is partly why I made this BroadcastThunk{T}. Probably all of this should just be disabled for arrays of things more complicated than numbers.

This makes no attempt to track how many places the same BroadcastThunk is going to be used. So the idea would be to only use it for cheap operations. That may include the gradient in split_bc_forwards, but not yet -- for now just the same operations for which the forward pass is lazy (fused).
