rrule for mean(f, x) is not vectorized? #733

Open
Red-Portal opened this issue Aug 17, 2023 · 2 comments

@Red-Portal

Hi, it seems that the rrule for mean(f, x) is not vectorized and therefore does not play nicely with CUDA:

```julia
using Zygote, CUDA, Statistics

julia> gradient(y -> mean(x -> x.^2, y), CUDA.randn(10))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
  [3] getindex
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
  [4] iterate
    @ ./abstractarray.jl:1220 [inlined]
  [5] iterate
    @ ./abstractarray.jl:1218 [inlined]
  [6] iterate
    @ ./generator.jl:44 [inlined]
  [7] collect(itr::Base.Generator{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ChainRules.var"#1655#1660"{Zygote.ZygoteRuleConfig{Zygote.Context{false}}, var"#24#26"}})
    @ Base ./array.jl:782
  [8] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::typeof(sum), f::var"#24#26", xs::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:102
  [9] rrule
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:76 [inlined]
 [10] #rrule#1808
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Statistics/statistics.jl:28 [inlined]
 [11] rrule
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Statistics/statistics.jl:21 [inlined]
 [12] chain_rrule
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:223 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
 [14] _pullback
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
 [15] _pullback
    @ ./REPL[14]:1 [inlined]
 [16] _pullback(ctx::Zygote.Context{false}, f::var"#23#25", args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [17] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:44
 [18] pullback
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:42 [inlined]
 [19] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:96
 [20] top-level scope
    @ REPL[14]:1
 [21] top-level scope
    @ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185
```

The problem seems to be that this line does not use map or broadcasting, but the comment there suggests that we can't use them here. Is there anything we can do?

By the way, sum(f, x) with the same f works fine, so I'm curious why the result differs. Don't both hit the same rrule?

```julia
julia> gradient(y -> sum(x -> x^2, y)/10, CUDA.randn(10))
(Float32[-0.03543221, -0.002124702, 0.068868384, -0.21756743, 0.234217, -0.16418666, -0.033367466, -0.26496077, 0.095435165, -0.044487894],)
```
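One possible user-side workaround, sketched under the assumption that the scalar indexing comes only from the per-element generator inside the mean(f, x) rule: broadcast f yourself and then take the plain mean, so that the gradient goes through Zygote's broadcasting rules instead. The helper sq is hypothetical, and the snippet only checks on a CPU array that both forms give the same gradient; whether mean(sq.(y)) avoids scalar indexing with x = CUDA.randn(10) is not verified here.

```julia
using Zygote, Statistics

# Hypothetical stand-in for the f in this issue.
sq(t) = t^2

x = randn(Float32, 10)

# Assumed workaround: broadcast f first, then call the one-argument mean,
# so no per-element rrule_via_ad generator is needed for the reduction.
g_broadcast = gradient(y -> mean(sq.(y)), x)[1]

# The form from the report, which uses the rrule for mean(f, x); fine on CPU.
g_mean_f = gradient(y -> mean(sq, y), x)[1]

# Both should equal the analytic gradient 2x / n.
@assert g_broadcast ≈ g_mean_f
@assert g_broadcast ≈ 2 .* x ./ length(x)
```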

@Red-Portal (Author)

Red-Portal commented Aug 19, 2023

This appears to be more complicated than it looked. It seems that gradient(y -> sum(x -> x^2, y)/10, CUDA.randn(10)) does not hit the sum(f, x) rrule, while mean(f, x) does, which is very strange. I have no idea which rrule is being hit for sum(f, x).

@mcabbott (Member)

Zygote has this rule for sum(f, xs::CuArray), which takes precedence over the one here:

https://github.com/FluxML/Zygote.jl/blob/d4562e330d588cb986604bb4f1942bf9fca8ecc5/src/lib/broadcast.jl#L372-L377

Note also that sum(x -> x^2, xs) is equivalent to sum(abs2, xs), which has a special rule. I think that mean(abs2, xs) goes through the rule here and should end up calling that one.
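A quick CPU check of the equivalence described above, assuming real-valued inputs (for complex inputs t^2 and abs2 differ). It only confirms that the closure, abs2, and mean(abs2, y) forms agree numerically; which rrule each one dispatches to for a CuArray is the open question in this thread and is not tested by this snippet.

```julia
using Zygote, Statistics

x = randn(Float32, 10)

# On real inputs, t -> t^2 and abs2 compute the same values,
# so their gradients should agree; abs2 is the one with a dedicated sum rule.
g_closure = gradient(y -> sum(t -> t^2, y), x)[1]
g_abs2    = gradient(y -> sum(abs2, y), x)[1]

# mean(abs2, y) should give the same gradient scaled by 1 / length(y).
g_mean    = gradient(y -> mean(abs2, y), x)[1]

@assert g_closure ≈ g_abs2
@assert g_mean ≈ g_abs2 ./ length(x)
```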

(One example above uses x -> x.^2, with an extra broadcast; there is some chance that this changes which path is taken in the sum(f, xs) rule.)

mcabbott added the GPU label on Oct 24, 2023