Attach rule to mapfoldl_impl not foldl #569

Open · wants to merge 10 commits into base: main

Conversation

mcabbott (Member) commented Jan 15, 2022

Closes #567, perhaps in the minimal way, by attaching these rules to internal functions which take positional arguments. The gradient for init is just @not_implemented for now.
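
For reference, the keyword front ends all funnel into this positional internal function, so one rule there covers foldl both with and without init. A quick check (these are Base internals, so subject to change):

julia> foldl(/, (1, 2, 3)) === Base.mapfoldl_impl(identity, /, Base._InitialValue(), (1, 2, 3))
true

julia> foldl(/, (1, 2, 3); init = 4) === Base.mapfoldl_impl(identity, /, 4, (1, 2, 3))
true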

One nice effect is that I think foldr may work too.

One weird effect is that accumulate!(f, y, x) will work, silently overwriting y. It does return a NotImplemented; maybe that helps. Xref #521

Non-vector shapes like accumulate(f, ::Matrix) take a different path, via Iterators.accumulate, and will miss the rule. So will accumulate(f, ::Tuple). Maybe for that case Base's code is OK.
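
For instance (dispatch paths as of the Julia versions discussed here):

julia> accumulate(+, (1, 2, 3))    # Tuple method, via Base.afoldl
(1, 3, 6)

julia> accumulate(+, [1 2; 3 4])   # Matrix method, via Iterators.accumulate -- misses the rule
2×2 Matrix{Int64}:
 1   6
 4  10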


Closes #672. Probably closes FluxML/Zygote.jl#1297.

oxinabox (Member)

What are the TODOs that would make this no longer [draft]?

mcabbott (Member Author)

I have forgotten. But one was to decide how unhappy we are about this:

accumulate!(f, y, x) will work, silently overwriting y.

And in general about hooking on deep inside Base. I didn't see a nicer way to hook onto accumulate.
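
A hypothetical session illustrating the concern (not from the PR's tests):

julia> using Zygote; y = zeros(3);

julia> Zygote.gradient(x -> sum(accumulate!(+, y, x)), [1.0, 2.0, 3.0]);

julia> y   # the primal pass silently wrote the cumulative sums into y
3-element Vector{Float64}:
 1.0
 3.0
 6.0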

ToucheSir (Contributor)

Following up on this, could the accumulate! behaviour be worked around by adding an explicit @opt_out for that function?

mcabbott (Member Author) commented Jul 19, 2022

fff84b5 separates out the Tuple method, for which this way of writing the rule makes more sense. And handles init by a separate rule.

It is, strangely, much slower inside AD than the tagged version.

And it breaks Yota.

The functions between foldl and mapfoldl_impl are these; shouldn't they be easy for AD to work through?

foldl(op, itr; kw...) = mapfoldl(identity, op, itr; kw...)
mapfoldl(f, op, itr; init=_InitialValue()) = mapfoldl_impl(f, op, init, itr)
julia> xn = Tuple(randn(10));

julia> xm = Tuple(rand(10,10) for _ in 1:10);

# Zygote

julia> @btime Zygote.gradient(x -> foldl(/, x), $xn);
  min 47.318 ns, mean 77.223 ns (1 allocation, 16 bytes)  # before
  min 4.381 μs, mean 4.692 μs (37 allocations, 2.16 KiB)  # after -- 100x slower

julia> @btime Zygote.gradient(x -> sum(abs2, foldl(*, x)), $xm);
  min 17.667 μs, mean 77.964 μs (53 allocations, 29.52 KiB)  # before
  min 19.708 μs, mean 23.791 μs (69 allocations, 27.48 KiB)  # after

julia> @btime Zygote.gradient(x -> Base.afoldl(/, x...), $xn);  # no rule -- much slower
  min 130.500 μs, mean 135.423 μs (413 allocations, 16.33 KiB)

julia> @btime Zygote.gradient(x -> sum(abs2, Base.afoldl(*, x...)), $xm);
  min 143.500 μs, mean 151.413 μs (384 allocations, 40.30 KiB)

# Diffractor

julia> @btime Diffractor.gradient(x -> foldl(/, x), $xn);
  min 29.271 ns, mean 30.017 ns (0 allocations)              # before
  min 350.632 ns, mean 400.959 ns (6 allocations, 672 bytes) # after -- 10x slower

julia> @btime Diffractor.gradient(x -> sum(abs2, foldl(*, x)), $xm);
  min 13.666 μs, mean 16.422 μs (29 allocations, 25.38 KiB)      # before
  min 162.584 μs, mean 218.275 μs (357 allocations, 168.42 KiB); # after

julia> @btime Diffractor.gradient(x -> Base.afoldl(/, x...), $xn);  # no rule -- better than Zygote
  min 352.882 ns, mean 419.163 ns (6 allocations, 672 bytes)

julia> @btime Diffractor.gradient(x -> sum(abs2, Base.afoldl(/, x...)), $xm)
  min 163.125 μs, mean 204.721 μs (357 allocations, 168.42 KiB)

# Yota

julia> @btime Yota.grad(x -> foldl(/, x), $xn);
  min 182.790 ns, mean 657.142 ns (3 allocations, 208 bytes)  # before
ERROR: No deriative rule found for op %3 = foldl(/, %2)::Float64, try defining it... # after -- fails

julia> @btime Yota.grad(x -> sum(abs2, foldl(*, x)), $xm);
  min 8.583 μs, mean 50.186 μs (21 allocations, 16.19 KiB)

julia> Yota.grad(x -> Base.afoldl(/, x...), xn);
ERROR: syntax: Slot objects should not occur in an AST

# Checking pieces?

julia> yyy = Yota.YotaRuleConfig()

julia> @code_warntype rrule(yyy, foldl, /, xn)  # before
julia> @code_warntype rrule(yyy, foldl, /, xn)[2](1.0)

julia> @code_warntype rrule(yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), xn)  # after
julia> @code_warntype rrule(yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), xn)[2](1.0)

julia> @btime rrule($yyy, foldl, /, $xn)[2](1.0);
  min 29.271 ns, mean 30.036 ns (0 allocations)

julia> @btime rrule($yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), $xn)[2](1.0);
  min 29.271 ns, mean 29.753 ns (0 allocations)

could the accumulate! behaviour be worked around by adding an explicit @opt_out for that function?

I don't see how. I think you can only opt out of functions which have rules, and those ones need to be called to work.
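
For reference, an opt-out is written against a rule signature, roughly like this (hypothetical signature, not from this PR):

ChainRulesCore.@opt_out rrule(::typeof(accumulate!), op, y::Vector, x::Vector)

All it does is tell AD to skip a more generic rule and differentiate the call itself -- which here still means mutating y.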

This needs JuliaDiff/ChainRulesCore.jl#567, now merged.

9af7a64 also adds mapfoldl, as map then foldl, for tuples.

Maybe also worth noting: moving the rule to _accumulate! means no such rule for tuples. But @less accumulate(/, (1,2,3)) shows this is pretty simple, and calls Base.afoldl. Perhaps the tuple foldl rule should be applied to Base.afoldl too (or vice versa).
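
The two do compute the same thing, afoldl being the left fold which Base unrolls for tuples:

julia> Base.afoldl(/, 1, 2, 3) == foldl(/, (1, 2, 3))
true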

mcabbott (Member Author)

Trying a bit to track this down today, I think the slowdown is just some quirk of Zygote's handling of keywords. So it's not the rule's fault, and anything which fixes the init problem will probably hit it. Diffractor no longer sees the slowdown seen above:

using Diffractor, ChainRulesCore
ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing
# Solves same error as   https://github.com/JuliaDiff/ChainRulesCore.jl/pull/503
xn = Tuple(randn(10));

@btime Diffractor.gradient(x -> foldl(/, x), $xn);
#   min 29.313 ns, mean 29.545 ns (0 allocations)  before (old rule on foldl)
#   min 29.313 ns, mean 29.522 ns (0 allocations)  after (new rule on Base.mapfoldl_impl)

@btime Diffractor.gradient(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
#   min 47.625 μs, mean 53.596 μs (569 allocations, 33.16 KiB)  before -- i.e. with no rule, just Base, NB μs
_foldl(op::G, itr; kw...) where {G} = _mapfoldl(identity, op, itr; kw...)
_mapfoldl(f::F, op::G, itr; init=Base._InitialValue()) where {F,G} = Base.mapfoldl_impl(f, op, init, itr)
@btime Diffractor.gradient(x -> _foldl(/, x), $xn);
#   min 56.542 μs, mean 62.279 μs (672 allocations, 38.78 KiB)  before -- i.e. with no rule, just Base, NB μs

import Zygote
@btime Zygote.gradient(x -> foldl(/, x), $xn);
#   min 47.402 ns, mean 48.592 ns (1 allocation, 16 bytes)  before
#   min 4.482 μs, mean 9.120 μs (37 allocations, 2.16 KiB)  after -- this I didn't like, above
# Same with Zygote#master, thus including https://github.com/FluxML/Zygote.jl/pull/1286

@btime Zygote.gradient(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
#   min 152.667 μs, mean 157.707 μs (494 allocations, 26.44 KiB)  before -- i.e. using no rule, just Base, NB μs
#   min 47.402 ns, mean 82.826 ns (1 allocation, 16 bytes)  after -- so the issue is Zygote & keywords

using Yota
@btime Yota.grad(x -> foldl(/, x), $xn);
#   min 235.140 ns, mean 251.834 ns (3 allocations, 208 bytes)  before
# error afterwards, doesn't track further?
ChainRulesCore.@non_differentiable Base._InitialValue()
@btime Yota.grad(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
#  min 231.805 ns, mean 250.267 ns (3 allocations, 208 bytes)  after

So I think we should merge this, if tests pass etc.

mcabbott marked this pull request as ready for review, August 19, 2022 05:08
ToucheSir (Contributor) commented Aug 19, 2022

Zygote tries to diff through the kwsorter definition (i.e. https://docs.julialang.org/en/v1/devdocs/functions/#Keyword-arguments), which includes control flow. It's very difficult to make this type-stable because it requires saving a different set of pullbacks for each branch (does anybody know how Diffractor does this?), but FluxML/Zygote.jl#1195 might help with runtime overhead.
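
For anyone curious, the sorter machinery is visible in the lowered code (output elided here; the exact form varies across Julia versions):

julia> Meta.@lower foldl(/, xs; init = 1.0)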

init, x
end
hobbits = accumulate(list; init=(start, nothing)) do (a, _), b
c, back = rrule_via_ad(config, op, a, b)
Contributor:

Is there any way to not capture the accumulated outputs (cs) in the pullback? It seems easy enough for tuples using map, but I'm unsure if the extra allocation would be welcomed for arrays.

Member Author:

Yes, you can write a for loop like this: FluxML/Zygote.jl#644 (comment). IMO this array method should probably be replaced, but not today.

Carrying c by updating a variable from inside accumulate was very slow; IIRC it hits the closure issue.
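
A minimal sketch of that pattern, for reference (the closure issue is julia#15276: a captured, re-assigned variable gets boxed):

julia> function slow_version(op, xs)
           c = nothing
           ys = accumulate(xs) do a, b
               c = op(a, b)   # `c` is captured and re-assigned, so it becomes a boxed, type-unstable Core.Box
           end
           ys, c
       end;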

ToucheSir (Contributor), Aug 20, 2022:

I actually tried that in FluxML/Flux.jl#2003. The main challenges are nested differentiation and handling the case when typeof(x |> f) != typeof(x |> f |> f) (you must widen, which means preallocating an array is impossible without return_type shenanigans).
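
For instance, with f = x -> [x] the output type changes at every application, so no fixed-eltype buffer can be preallocated:

julia> f = x -> [x];

julia> typeof(f(1)), typeof(f(f(1)))
(Vector{Int64}, Vector{Vector{Int64}})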

So, assuming type inference cooperates, the accumulate approach seems no less promising. Would there be any objections to a post-processing step like the following which allows the GC to clean up intermediate outputs before the pullback?

# ... y = first(last(hobbits))
# If outputs are (recursively) allocated inline, we're less worried about memory overhead
# and the GC can't free them individually anyways. 
if !isbitstype(eltype(hobbits))
  hobbits = map(((_, pb)) -> (nothing, pb),  hobbits)
end
# axe = axes(x) ...

Member Author:

Yes, the mutation has these problems, but it was much quicker; maybe it can be used when safe.

The intention with writing foldl in terms of accumulate was to allow for 2nd derivatives, but I'm not sure this actually works right now.

Re saving memory, we can add something like unzip_accumulate(f, xs; init) = StructArrays.components(StructArray(Iterators.accumulate(f, xs; init))) to free the bits we don't need anymore.
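
Spelled out, so that the session below is reproducible (assuming StructArrays is loaded):

julia> using StructArrays

julia> unzip_accumulate(f, xs; init) = StructArrays.components(StructArray(Iterators.accumulate(f, xs; init)));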

julia> accumulate([1,2,3], init=(4,5)) do prev, this
         this .+ prev
       end
3-element Vector{Tuple{Int64, Int64}}:
 (5, 6)
 (7, 8)
 (10, 11)

julia> unzip_accumulate([1,2,3], init=(4,5)) do prev, this
         this .+ prev
       end
([5, 7, 10], [6, 8, 11])

But this PR would like to kick the can down the road on such improvements.

(And others: it returns @not_implemented for accumulate's init; that might be easy to do better, but I'm tired of adding tests... at least it's no longer wrong.)

ToucheSir (Contributor) commented Aug 20, 2022

After looking into Diffractor, I think whatever it does happens outside the actual AD transform (perhaps leaving control flow intact is enough), but the ability to have unused branches/blocks in the keyword sorter pruned in the final IR does wonders for type stability. Inspired by this, FluxML/Zygote.jl#446 (comment) has some thoughts on how we might do something similar there.

mcabbott (Member Author) commented Aug 20, 2022

The remaining test failure is 1.8 on x86:

Testing rulesets/LinearAlgebra/structured.jl:
170.036836 seconds (119.56 M allocations: 3.781 GiB, 9.09% gc time, 97.04% compilation time)
Testing rulesets/LinearAlgebra/symmetric.jl:
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc

signal (6): Aborted
in expression starting at /home/runner/work/ChainRules.jl/ChainRules.jl/test/rulesets/LinearAlgebra/symmetric.jl:1
__kernel_vsyscall at linux-gate.so.1 (unknown line)
gsignal at /lib32/libc.so.6 (unknown line)
(log: https://github.com/JuliaDiff/ChainRules.jl/runs/7934430028)

Also happened https://github.com/JuliaDiff/ChainRules.jl/runs/7933271950?check_suite_focus=true with #667 (no longer needed). Or:

Testing rulesets/LinearAlgebra/symmetric.jl:
Internal error: encountered unexpected error in runtime:
OutOfMemoryError()
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc

signal (6): Aborted
in expression starting at /home/runner/work/ChainRules.jl/ChainRules.jl/test/rulesets/LinearAlgebra/symmetric.jl:1
ERROR: LoadError: Package ChainRules errored during testing (received signal: 6)
(log: https://github.com/JuliaDiff/ChainRules.jl/runs/8000812940)

mcabbott (Member Author)

Can this be merged?

It's not the last word, as noted above, but it is a step forwards.


#####
##### `accumulate`
#####

# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`.

# Also like `foldl`, the version with a keyword `init` can't easily be given a gradient.
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`
Member:

But we don't at present support getting back a gradient for init, except if it's nothing and then it doesn't matter.
So we might as well put this on accumulate?

I am a little uncomfortable putting rules on mutating functions.
Though perhaps this one is safe as we are always fully overwriting y and never reading it?
A comment to that effect would be good if so.

Member Author:

The intention was to move both to functions with positional init, and this mutating function was the best option I could find in Base's dispatch.

Then it could have a gradient for init. I didn't get around to writing one, mostly got tired of fighting tests. But at least step 1 makes step 2 easier; it would be a small PR. And for now it returns @not_implemented, which is better than a silent zero, in theory at least.

Member Author:

I do think this is unsafe the same way that fill! is unsafe. Except that in practice, I think it's much less likely to cause problems, as anyone who gets to accumulate! has probably been trained out of hoping that mutation will work.

The originally envisaged use case was that the 2nd derivative of foldl would involve this accumulate gradient. But I don't recall whether I ever checked that.
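
A hypothetical way to check it (not verified here, and reverse-over-reverse AD is fragile) would be a gradient of a gradient, which has to differentiate the accumulate inside foldl's pullback:

julia> using Zygote

julia> Zygote.gradient(x -> sum(Zygote.gradient(y -> foldl(*, y), x)[1]), [1.0, 2.0, 3.0])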

else
x, init
something(init), x
end
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
Member:

Suggested change
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
# "The Hobbit", or "There and Back Again"
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b

oxinabox (Member) left a comment:

If we can resolve the comment about why we do _accumulate!, and whether it is actually safe, I think we can merge this.
If it breaks stuff we can always revert.

(Two resolved review threads on src/rulesets/Base/mapreduce.jl, now outdated.)
oxinabox (Member) commented Mar 7, 2023

If you think this is good to go then we can merge it.
If we see it is breaking things we can revert it.

Comment on lines +536 to +539
if init === Base._InitialValue() # `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
end
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
Contributor:

Suggested change
if init === Base._InitialValue() # `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
end
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
if init === Base._InitialValue() # `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
d_init = NoTangent()
else
d_init = trio[end][2]
end

Would this work?

Member Author:

Probably!

It's been a while, but my memory is that I mostly got tired of making tests, so I thought I'd leave that for later.

Co-authored-by: Frames White <[email protected]>